diff --git a/alembic/versions/25f4e2a1c9d8_add_workspace_git_sync_tables.py b/alembic/versions/25f4e2a1c9d8_add_workspace_git_sync_tables.py new file mode 100644 index 0000000000..2a685dfb77 --- /dev/null +++ b/alembic/versions/25f4e2a1c9d8_add_workspace_git_sync_tables.py @@ -0,0 +1,420 @@ +"""add workspace git sync tables + +Revision ID: 25f4e2a1c9d8 +Revises: 9b52f7f18a31 +Create Date: 2026-06-05 00:00:00.000000 + +""" + +import uuid +from collections.abc import Sequence +from datetime import datetime + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op +from tracecat.db.tenant_rls import ( + disable_workspace_table_rls, + enable_workspace_table_rls, +) + +# revision identifiers, used by Alembic. +revision: str = "25f4e2a1c9d8" +down_revision: str | None = "9b52f7f18a31" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +WORKSPACE_SYNC_TABLES = ( + "workspace_sync_state", + "workspace_sync_resource_mapping", + "workspace_sync_event", + "workspace_sync_changeset", + "workspace_sync_changeset_item", + "workspace_sync_materialization", +) + + +def _timestamps() -> list[sa.Column[datetime]]: + return [ + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + ] + + +def _tenant_columns() -> list[sa.Column[uuid.UUID]]: + return [ + sa.Column("workspace_id", postgresql.UUID(as_uuid=True), nullable=False), + ] + + +def _record_columns() -> list[sa.Column[uuid.UUID] | sa.Column[int]]: + return [ + sa.Column("surrogate_id", sa.Integer(), sa.Identity(), nullable=False), + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + ] + + +def _tenant_fks(table_name: str) -> list[sa.Constraint]: + return [ + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name=op.f(f"fk_{table_name}_workspace_id_workspace"), + ondelete="CASCADE", + ), + ] + + +def upgrade() -> None: + op.create_table( + "workspace_sync_state", + *_record_columns(), + *_tenant_columns(), + sa.Column( + "provider", sa.String(length=32), server_default="git", nullable=False + ), + sa.Column("repo_url", sa.String(), nullable=False), + sa.Column("target_ref", sa.String(), server_default="main", nullable=False), + sa.Column("base_commit_sha", sa.String(), nullable=True), + sa.Column("base_tree_sha", sa.String(), nullable=True), + sa.Column("base_spec_hash", sa.String(), nullable=True), + sa.Column("last_remote_commit_sha", sa.String(), nullable=True), + sa.Column("last_remote_tree_sha", sa.String(), nullable=True), + sa.Column( + "status", + sa.String(length=32), + server_default="never_synced", + nullable=False, + ), + sa.Column("last_direction", sa.String(length=16), nullable=True), + sa.Column( + "last_error", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + sa.Column("last_synced_at", sa.TIMESTAMP(timezone=True), nullable=True), + *_timestamps(), + *_tenant_fks("workspace_sync_state"), + sa.PrimaryKeyConstraint("surrogate_id", name=op.f("pk_workspace_sync_state")), + sa.UniqueConstraint("id", name=op.f("uq_workspace_sync_state_id")), + sa.UniqueConstraint( + "workspace_id", + "provider", + "repo_url", + "target_ref", + name="uq_workspace_sync_state_workspace_provider_repo_ref", + ), + ) + op.create_index( + op.f("ix_workspace_sync_state_id"), + "workspace_sync_state", + ["id"], + unique=True, + ) + op.create_index( + op.f("ix_workspace_sync_state_workspace_id"), + "workspace_sync_state", + ["workspace_id"], + unique=False, + ) + + op.create_table( + "workspace_sync_resource_mapping", + *_record_columns(), + *_tenant_columns(), + sa.Column( + "provider", sa.String(length=32), server_default="git", nullable=False + ), + sa.Column("resource_type", sa.String(length=64), nullable=False), + sa.Column("source_id", sa.String(), nullable=False), + sa.Column("source_path", sa.String(), nullable=True), + sa.Column("local_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("last_synced_commit_sha", sa.String(), nullable=True), + sa.Column("last_synced_spec_hash", sa.String(), nullable=True), + sa.Column( + "sync_status", + sa.String(length=32), + server_default="untracked", + nullable=False, + ), + *_timestamps(), + *_tenant_fks("workspace_sync_resource_mapping"), + sa.PrimaryKeyConstraint( + "surrogate_id", name=op.f("pk_workspace_sync_resource_mapping") + ), + sa.UniqueConstraint("id", name=op.f("uq_workspace_sync_resource_mapping_id")), + sa.UniqueConstraint( + "workspace_id", + "provider", + "resource_type", + "source_id", + name="uq_workspace_sync_mapping_source", + ), + sa.UniqueConstraint( + "workspace_id", + "provider", + "resource_type", + "local_id", + name="uq_workspace_sync_mapping_local", + ), + ) + op.create_index( + op.f("ix_workspace_sync_resource_mapping_id"), + "workspace_sync_resource_mapping", + ["id"], + unique=True, + ) + op.create_index( + op.f("ix_workspace_sync_resource_mapping_workspace_id"), + "workspace_sync_resource_mapping", + ["workspace_id"], + unique=False, + ) + + op.create_table( + "workspace_sync_event", + *_record_columns(), + *_tenant_columns(), + sa.Column( + "provider", sa.String(length=32), server_default="git", nullable=False + ), + sa.Column("resource_type", sa.String(length=64), nullable=False), + sa.Column("source_id", sa.String(), nullable=True), + sa.Column("local_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("operation", sa.String(length=32), nullable=False), + sa.Column("actor_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("base_commit_sha", sa.String(), nullable=True), + sa.Column("before_spec_hash", sa.String(), nullable=True), + sa.Column("after_spec_hash", sa.String(), nullable=True), + sa.Column( + "affected_paths", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'[]'::jsonb"), + nullable=False, + ), + sa.Column( + "metadata", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("superseded_by", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("changeset_id", postgresql.UUID(as_uuid=True), nullable=True), + *_timestamps(), + *_tenant_fks("workspace_sync_event"), + sa.PrimaryKeyConstraint("surrogate_id", name=op.f("pk_workspace_sync_event")), + sa.UniqueConstraint("id", name=op.f("uq_workspace_sync_event_id")), + ) + op.create_index( + op.f("ix_workspace_sync_event_id"), + "workspace_sync_event", + ["id"], + unique=True, + ) + op.create_index( + op.f("ix_workspace_sync_event_workspace_id"), + "workspace_sync_event", + ["workspace_id"], + unique=False, + ) + + op.create_table( + "workspace_sync_changeset", + *_record_columns(), + *_tenant_columns(), + sa.Column( + "provider", sa.String(length=32), server_default="git", nullable=False + ), + sa.Column("title", sa.String(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("base_commit_sha", sa.String(), nullable=True), + sa.Column("base_spec_hash", sa.String(), nullable=True), + sa.Column( + "selected_resources", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'[]'::jsonb"), + nullable=False, + ), + sa.Column( + "selected_paths", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'[]'::jsonb"), + nullable=False, + ), + sa.Column( + "validation_status", + sa.String(length=32), + server_default="pending", + nullable=False, + ), + sa.Column( + "validation_result", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "status", sa.String(length=32), server_default="open", nullable=False + ), + sa.Column("created_by", postgresql.UUID(as_uuid=True), nullable=True), + *_timestamps(), + *_tenant_fks("workspace_sync_changeset"), + sa.PrimaryKeyConstraint( + "surrogate_id", name=op.f("pk_workspace_sync_changeset") + ), + sa.UniqueConstraint("id", name=op.f("uq_workspace_sync_changeset_id")), + ) + op.create_index( + op.f("ix_workspace_sync_changeset_id"), + "workspace_sync_changeset", + ["id"], + unique=True, + ) + op.create_index( + op.f("ix_workspace_sync_changeset_workspace_id"), + "workspace_sync_changeset", + ["workspace_id"], + unique=False, + ) + + op.create_table( + "workspace_sync_changeset_item", + *_record_columns(), + *_tenant_columns(), + sa.Column("changeset_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("resource_type", sa.String(length=64), nullable=False), + sa.Column("source_id", sa.String(), nullable=False), + sa.Column("source_path", sa.String(), nullable=True), + sa.Column("local_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("operation", sa.String(length=32), nullable=False), + sa.Column("spec_hash", sa.String(), nullable=True), + sa.Column( + "dependencies", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'[]'::jsonb"), + nullable=False, + ), + *_timestamps(), + *_tenant_fks("workspace_sync_changeset_item"), + sa.ForeignKeyConstraint( + ["changeset_id"], + ["workspace_sync_changeset.id"], + name=op.f( + "fk_workspace_sync_changeset_item_changeset_id_workspace_sync_changeset" + ), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint( + "surrogate_id", name=op.f("pk_workspace_sync_changeset_item") + ), + sa.UniqueConstraint("id", name=op.f("uq_workspace_sync_changeset_item_id")), + sa.UniqueConstraint( + "changeset_id", + "resource_type", + "source_id", + name="uq_workspace_sync_changeset_item_resource", + ), + ) + op.create_index( + op.f("ix_workspace_sync_changeset_item_id"), + "workspace_sync_changeset_item", + ["id"], + unique=True, + ) + op.create_index( + op.f("ix_workspace_sync_changeset_item_workspace_id"), + "workspace_sync_changeset_item", + ["workspace_id"], + unique=False, + ) + op.create_index( + op.f("ix_workspace_sync_changeset_item_changeset_id"), + "workspace_sync_changeset_item", + ["changeset_id"], + unique=False, + ) + + op.create_table( + "workspace_sync_materialization", + *_record_columns(), + *_tenant_columns(), + sa.Column("changeset_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column( + "provider", sa.String(length=32), server_default="git", nullable=False + ), + sa.Column("branch", sa.String(), nullable=False), + sa.Column("base_ref", sa.String(), nullable=True), + sa.Column("pr_number", sa.Integer(), nullable=True), + sa.Column("pr_url", sa.String(), nullable=True), + sa.Column( + "commit_shas", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'[]'::jsonb"), + nullable=False, + ), + sa.Column( + "status", + sa.String(length=32), + server_default="pending", + nullable=False, + ), + sa.Column("error", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + *_timestamps(), + *_tenant_fks("workspace_sync_materialization"), + sa.ForeignKeyConstraint( + ["changeset_id"], + ["workspace_sync_changeset.id"], + name=op.f( + "fk_workspace_sync_materialization_changeset_id_workspace_sync_changeset" + ), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint( + "surrogate_id", name=op.f("pk_workspace_sync_materialization") + ), + sa.UniqueConstraint("id", name=op.f("uq_workspace_sync_materialization_id")), + ) + op.create_index( + op.f("ix_workspace_sync_materialization_id"), + "workspace_sync_materialization", + ["id"], + unique=True, + ) + op.create_index( + op.f("ix_workspace_sync_materialization_workspace_id"), + "workspace_sync_materialization", + ["workspace_id"], + unique=False, + ) + op.create_index( + op.f("ix_workspace_sync_materialization_changeset_id"), + "workspace_sync_materialization", + ["changeset_id"], + unique=False, + ) + + for table in WORKSPACE_SYNC_TABLES: + op.execute(enable_workspace_table_rls(table)) + + +def downgrade() -> None: + for table in reversed(WORKSPACE_SYNC_TABLES): + op.execute(disable_workspace_table_rls(table)) + + op.drop_table("workspace_sync_materialization") + op.drop_table("workspace_sync_changeset_item") + op.drop_table("workspace_sync_changeset") + op.drop_table("workspace_sync_event") + op.drop_table("workspace_sync_resource_mapping") + op.drop_table("workspace_sync_state") diff --git a/alembic/versions/b3f7a92d1c4e_add_rendered_files_to_workspace_sync_changesets.py b/alembic/versions/b3f7a92d1c4e_add_rendered_files_to_workspace_sync_changesets.py new file mode 100644 index 0000000000..c8540af712 --- /dev/null +++ b/alembic/versions/b3f7a92d1c4e_add_rendered_files_to_workspace_sync_changesets.py @@ -0,0 +1,36 @@ +"""add rendered files to workspace sync changesets + +Revision ID: b3f7a92d1c4e +Revises: 25f4e2a1c9d8 +Create Date: 2026-06-07 00:00:00.000000 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b3f7a92d1c4e" +down_revision: str | None = "25f4e2a1c9d8" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "workspace_sync_changeset", + sa.Column( + "rendered_files", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + ) + + +def downgrade() -> None: + op.drop_column("workspace_sync_changeset", "rendered_files") diff --git a/frontend/src/client/schemas.gen.ts b/frontend/src/client/schemas.gen.ts index e6b1e57590..a6cff45fca 100644 --- a/frontend/src/client/schemas.gen.ts +++ b/frontend/src/client/schemas.gen.ts @@ -9161,6 +9161,155 @@ export const $CaseViewedEventRead = { description: "Event for when a case is viewed.", } as const +export const $ChangeSetCreate = { + properties: { + title: { + type: "string", + minLength: 1, + title: "Title", + }, + description: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Description", + }, + resources: { + items: { + $ref: "#/components/schemas/ResourceRef", + }, + type: "array", + title: "Resources", + }, + }, + type: "object", + required: ["title", "resources"], + title: "ChangeSetCreate", +} as const + +export const $ChangeSetExport = { + properties: { + message: { + type: "string", + title: "Message", + }, + branch: { + type: "string", + title: "Branch", + }, + create_pr: { + type: "boolean", + title: "Create Pr", + default: false, + }, + pr_base_branch: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Pr Base Branch", + }, + }, + type: "object", + required: ["message", "branch"], + title: "ChangeSetExport", +} as const + +export const $ChangeSetRead = { + properties: { + id: { + type: "string", + format: "uuid", + title: "Id", + }, + title: { + type: "string", + title: "Title", + }, + description: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Description", + }, + base_commit_sha: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Base Commit Sha", + }, + base_spec_hash: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Base Spec Hash", + }, + selected_resources: { + items: { + additionalProperties: true, + type: "object", + }, + type: "array", + title: "Selected Resources", + }, + selected_paths: { + items: { + type: "string", + }, + type: "array", + title: "Selected Paths", + }, + validation_status: { + type: "string", + title: "Validation Status", + }, + validation_result: { + additionalProperties: true, + type: "object", + title: "Validation Result", + }, + status: { + type: "string", + title: "Status", + }, + }, + type: "object", + required: [ + "id", + "title", + "selected_resources", + "selected_paths", + "validation_status", + "validation_result", + "status", + ], + title: "ChangeSetRead", +} as const + export const $ChannelType = { type: "string", enum: ["slack"], @@ -10173,6 +10322,68 @@ export const $CommentUpdatedEventRead = { description: "Event for when a top-level comment is updated.", } as const +export const $CommitInfo = { + properties: { + status: { + $ref: "#/components/schemas/PushStatus", + }, + sha: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Sha", + }, + ref: { + type: "string", + title: "Ref", + }, + base_ref: { + type: "string", + title: "Base Ref", + }, + pr_url: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Pr Url", + }, + pr_number: { + anyOf: [ + { + type: "integer", + }, + { + type: "null", + }, + ], + title: "Pr Number", + }, + pr_reused: { + type: "boolean", + title: "Pr Reused", + default: false, + }, + message: { + type: "string", + title: "Message", + default: "", + }, + }, + type: "object", + required: ["status", "sha", "ref", "base_ref"], + title: "CommitInfo", +} as const + export const $ContinueRunRequest = { properties: { kind: { @@ -18562,6 +18773,13 @@ export const $PullResult = { title: "PullResult", } as const +export const $PushStatus = { + type: "string", + enum: ["committed", "no_op"], + title: "PushStatus", + description: "Status of a push/commit operation.", +} as const + export const $RateLimitEvent = { properties: { rate_limit_info: { @@ -19896,6 +20114,45 @@ export const $ResolvedAttachedSubagentRef = { "Persisted subagent ref with immutable preset/version identifiers.", } as const +export const $ResourceRef = { + properties: { + resource_type: { + type: "string", + title: "Resource Type", + }, + source_id: { + type: "string", + title: "Source Id", + }, + source_path: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Source Path", + }, + local_id: { + anyOf: [ + { + type: "string", + format: "uuid", + }, + { + type: "null", + }, + ], + title: "Local Id", + }, + }, + type: "object", + required: ["resource_type", "source_id"], + title: "ResourceRef", +} as const + export const $ResponseInteraction = { properties: { type: { @@ -23397,6 +23654,26 @@ export const $StringListFieldChange = { description: "List diff for preset version fields.", } as const +export const $SyncOperation = { + type: "string", + enum: ["create", "update", "delete", "archive", "disable"], + title: "SyncOperation", +} as const + +export const $SyncStateStatus = { + type: "string", + enum: [ + "never_synced", + "clean", + "local_dirty", + "remote_ahead", + "diverged", + "conflicted", + "error", + ], + title: "SyncStateStatus", +} as const + export const $SyntaxToken = { properties: { type: { @@ -31276,6 +31553,218 @@ export const $WorkspaceSettingsUpdate = { title: "WorkspaceSettingsUpdate", } as const +export const $WorkspaceSyncExportResult = { + properties: { + changeset_id: { + type: "string", + format: "uuid", + title: "Changeset Id", + }, + commit: { + $ref: "#/components/schemas/CommitInfo", + }, + }, + type: "object", + required: ["changeset_id", "commit"], + title: "WorkspaceSyncExportResult", +} as const + +export const $WorkspaceSyncPendingChange = { + properties: { + resource_type: { + type: "string", + title: "Resource Type", + }, + source_id: { + type: "string", + title: "Source Id", + }, + source_path: { + type: "string", + title: "Source Path", + }, + local_id: { + anyOf: [ + { + type: "string", + format: "uuid", + }, + { + type: "null", + }, + ], + title: "Local Id", + }, + operation: { + $ref: "#/components/schemas/SyncOperation", + }, + title: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Title", + }, + alias: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Alias", + }, + before_spec_hash: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Before Spec Hash", + }, + after_spec_hash: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "After Spec Hash", + }, + exportable: { + type: "boolean", + title: "Exportable", + default: true, + }, + }, + type: "object", + required: ["resource_type", "source_id", "source_path", "operation"], + title: "WorkspaceSyncPendingChange", +} as const + +export const $WorkspaceSyncPendingChanges = { + properties: { + base_spec_hash: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Base Spec Hash", + }, + local_spec_hash: { + type: "string", + title: "Local Spec Hash", + }, + changes: { + items: { + $ref: "#/components/schemas/WorkspaceSyncPendingChange", + }, + type: "array", + title: "Changes", + }, + }, + type: "object", + required: ["local_spec_hash"], + title: "WorkspaceSyncPendingChanges", +} as const + +export const $WorkspaceSyncStatus = { + properties: { + status: { + $ref: "#/components/schemas/SyncStateStatus", + }, + base_spec_hash: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Base Spec Hash", + }, + local_spec_hash: { + type: "string", + title: "Local Spec Hash", + }, + remote_spec_hash: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Remote Spec Hash", + }, + base_commit_sha: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Base Commit Sha", + }, + remote_commit_sha: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Remote Commit Sha", + }, + target_ref: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Target Ref", + }, + pending_change_count: { + type: "integer", + title: "Pending Change Count", + default: 0, + }, + diagnostics: { + items: { + $ref: "#/components/schemas/PullDiagnostic", + }, + type: "array", + title: "Diagnostics", + }, + }, + type: "object", + required: ["status", "base_spec_hash", "local_spec_hash"], + title: "WorkspaceSyncStatus", +} as const + export const $WorkspaceUpdate = { properties: { name: { diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 3ef8be3884..8333d883ff 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -838,14 +838,22 @@ import type { WorkflowsCreateWorkflowDefinitionData, WorkflowsCreateWorkflowDefinitionResponse, WorkflowsCreateWorkflowResponse, + WorkflowsCreateWorkspaceSyncChangesetData, + WorkflowsCreateWorkspaceSyncChangesetResponse, WorkflowsDeleteWorkflowData, WorkflowsDeleteWorkflowResponse, WorkflowsExportWorkflowData, WorkflowsExportWorkflowResponse, + WorkflowsExportWorkspaceSyncChangesetData, + WorkflowsExportWorkspaceSyncChangesetResponse, WorkflowsGetWorkflowData, WorkflowsGetWorkflowDefinitionData, WorkflowsGetWorkflowDefinitionResponse, WorkflowsGetWorkflowResponse, + WorkflowsGetWorkspaceSyncChangesetData, + WorkflowsGetWorkspaceSyncChangesetResponse, + WorkflowsGetWorkspaceSyncStatusData, + WorkflowsGetWorkspaceSyncStatusResponse, WorkflowsListTagsData, WorkflowsListTagsResponse, WorkflowsListWorkflowBranchesData, @@ -856,6 +864,10 @@ import type { WorkflowsListWorkflowDefinitionsResponse, WorkflowsListWorkflowsData, WorkflowsListWorkflowsResponse, + WorkflowsListWorkspaceSyncChangesetsData, + WorkflowsListWorkspaceSyncChangesetsResponse, + WorkflowsListWorkspaceSyncPendingChangesData, + WorkflowsListWorkspaceSyncPendingChangesResponse, WorkflowsMoveWorkflowToFolderData, WorkflowsMoveWorkflowToFolderResponse, WorkflowsPublishWorkflowData, @@ -3281,6 +3293,158 @@ export const workflowsListWorkflowBranches = ( }) } +/** + * Get Workspace Sync Status + * Get workspace-level Git sync status for the configured repository. + * @param data The data for the request. + * @param data.workspaceId + * @returns WorkspaceSyncStatus Successful Response + * @throws ApiError + */ +export const workflowsGetWorkspaceSyncStatus = ( + data: WorkflowsGetWorkspaceSyncStatusData +): CancelablePromise => { + return __request(OpenAPI, { + method: "GET", + url: "/workspaces/{workspace_id}/workflows/sync/status", + path: { + workspace_id: data.workspaceId, + }, + errors: { + 422: "Validation Error", + }, + }) +} + +/** + * List Workspace Sync Pending Changes + * List local syncable workspace changes pending Git export. + * @param data The data for the request. + * @param data.workspaceId + * @returns WorkspaceSyncPendingChanges Successful Response + * @throws ApiError + */ +export const workflowsListWorkspaceSyncPendingChanges = ( + data: WorkflowsListWorkspaceSyncPendingChangesData +): CancelablePromise => { + return __request(OpenAPI, { + method: "GET", + url: "/workspaces/{workspace_id}/workflows/sync/pending", + path: { + workspace_id: data.workspaceId, + }, + errors: { + 422: "Validation Error", + }, + }) +} + +/** + * List Workspace Sync Changesets + * List workspace sync ChangeSets. + * @param data The data for the request. + * @param data.workspaceId + * @param data.limit + * @returns ChangeSetRead Successful Response + * @throws ApiError + */ +export const workflowsListWorkspaceSyncChangesets = ( + data: WorkflowsListWorkspaceSyncChangesetsData +): CancelablePromise => { + return __request(OpenAPI, { + method: "GET", + url: "/workspaces/{workspace_id}/workflows/sync/changesets", + path: { + workspace_id: data.workspaceId, + }, + query: { + limit: data.limit, + }, + errors: { + 422: "Validation Error", + }, + }) +} + +/** + * Create Workspace Sync Changeset + * Create a workspace sync ChangeSet from selected pending resources. + * @param data The data for the request. + * @param data.workspaceId + * @param data.requestBody + * @returns ChangeSetRead Successful Response + * @throws ApiError + */ +export const workflowsCreateWorkspaceSyncChangeset = ( + data: WorkflowsCreateWorkspaceSyncChangesetData +): CancelablePromise => { + return __request(OpenAPI, { + method: "POST", + url: "/workspaces/{workspace_id}/workflows/sync/changesets", + path: { + workspace_id: data.workspaceId, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 422: "Validation Error", + }, + }) +} + +/** + * Get Workspace Sync Changeset + * Get a workspace sync ChangeSet. + * @param data The data for the request. + * @param data.changesetId + * @param data.workspaceId + * @returns ChangeSetRead Successful Response + * @throws ApiError + */ +export const workflowsGetWorkspaceSyncChangeset = ( + data: WorkflowsGetWorkspaceSyncChangesetData +): CancelablePromise => { + return __request(OpenAPI, { + method: "GET", + url: "/workspaces/{workspace_id}/workflows/sync/changesets/{changeset_id}", + path: { + changeset_id: data.changesetId, + workspace_id: data.workspaceId, + }, + errors: { + 422: "Validation Error", + }, + }) +} + +/** + * Export Workspace Sync Changeset + * Export a workspace sync ChangeSet to a Git branch and optional PR. + * @param data The data for the request. + * @param data.changesetId + * @param data.workspaceId + * @param data.requestBody + * @returns WorkspaceSyncExportResult Successful Response + * @throws ApiError + */ +export const workflowsExportWorkspaceSyncChangeset = ( + data: WorkflowsExportWorkspaceSyncChangesetData +): CancelablePromise => { + return __request(OpenAPI, { + method: "POST", + url: "/workspaces/{workspace_id}/workflows/sync/changesets/{changeset_id}/export", + path: { + changeset_id: data.changesetId, + workspace_id: data.workspaceId, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 422: "Validation Error", + }, + }) +} + /** * Pull Workflows * Pull workflows from Git repository at specific commit. diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index 41d41fc327..d43e92efd4 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -2415,6 +2415,36 @@ export type CaseViewedEventRead = { created_at: string } +export type ChangeSetCreate = { + title: string + description?: string | null + resources: Array +} + +export type ChangeSetExport = { + message: string + branch: string + create_pr?: boolean + pr_base_branch?: string | null +} + +export type ChangeSetRead = { + id: string + title: string + description?: string | null + base_commit_sha?: string | null + base_spec_hash?: string | null + selected_resources: Array<{ + [key: string]: unknown + }> + selected_paths: Array + validation_status: string + validation_result: { + [key: string]: unknown + } + status: string +} + /** * Supported external channel types. */ @@ -2836,6 +2866,17 @@ export type CommentUpdatedEventRead = { created_at: string } +export type CommitInfo = { + status: PushStatus + sha: string | null + ref: string + base_ref: string + pr_url?: string | null + pr_number?: number | null + pr_reused?: boolean + message?: string +} + /** * Payload to continue a CE run after collecting approvals. */ @@ -5605,6 +5646,11 @@ export type PullResult = { message: string } +/** + * Status of a push/commit operation. + */ +export type PushStatus = "committed" | "no_op" + export type RateLimitEvent = { rate_limit_info: RateLimitInfo uuid: string @@ -6047,6 +6093,13 @@ export type ResolvedAttachedSubagentRef = { preset_version_id: string } +export type ResourceRef = { + resource_type: string + source_id: string + source_path?: string | null + local_id?: string | null +} + /** * Configuration for a response interaction. */ @@ -7030,6 +7083,22 @@ export type StringListFieldChange = { removed?: Array } +export type SyncOperation = + | "create" + | "update" + | "delete" + | "archive" + | "disable" + +export type SyncStateStatus = + | "never_synced" + | "clean" + | "local_dirty" + | "remote_ahead" + | "diverged" + | "conflicted" + | "error" + export type SyntaxToken = { type: string value: string @@ -9253,6 +9322,42 @@ export type WorkspaceSettingsUpdate = { validate_attachment_magic_number?: boolean | null } +export type WorkspaceSyncExportResult = { + changeset_id: string + commit: CommitInfo +} + +export type WorkspaceSyncPendingChange = { + resource_type: string + source_id: string + source_path: string + local_id?: string | null + operation: SyncOperation + title?: string | null + alias?: string | null + before_spec_hash?: string | null + after_spec_hash?: string | null + exportable?: boolean +} + +export type WorkspaceSyncPendingChanges = { + base_spec_hash?: string | null + local_spec_hash: string + changes?: Array +} + +export type WorkspaceSyncStatus = { + status: SyncStateStatus + base_spec_hash: string | null + local_spec_hash: string + remote_spec_hash?: string | null + base_commit_sha?: string | null + remote_commit_sha?: string | null + target_ref?: string | null + pending_change_count?: number + diagnostics?: Array +} + export type WorkspaceUpdate = { name?: string | null settings?: WorkspaceSettingsUpdate | null @@ -10121,6 +10226,49 @@ export type WorkflowsListWorkflowBranchesData = { export type WorkflowsListWorkflowBranchesResponse = Array +export type WorkflowsGetWorkspaceSyncStatusData = { + workspaceId: string +} + +export type WorkflowsGetWorkspaceSyncStatusResponse = WorkspaceSyncStatus + +export type WorkflowsListWorkspaceSyncPendingChangesData = { + workspaceId: string +} + +export type WorkflowsListWorkspaceSyncPendingChangesResponse = + WorkspaceSyncPendingChanges + +export type WorkflowsListWorkspaceSyncChangesetsData = { + limit?: number + workspaceId: string +} + +export type WorkflowsListWorkspaceSyncChangesetsResponse = Array + +export type WorkflowsCreateWorkspaceSyncChangesetData = { + requestBody: ChangeSetCreate + workspaceId: string +} + +export type WorkflowsCreateWorkspaceSyncChangesetResponse = ChangeSetRead + +export type WorkflowsGetWorkspaceSyncChangesetData = { + changesetId: string + workspaceId: string +} + +export type WorkflowsGetWorkspaceSyncChangesetResponse = ChangeSetRead + +export type WorkflowsExportWorkspaceSyncChangesetData = { + changesetId: string + requestBody: ChangeSetExport + workspaceId: string +} + +export type WorkflowsExportWorkspaceSyncChangesetResponse = + WorkspaceSyncExportResult + export type WorkflowsPullWorkflowsData = { requestBody: WorkflowSyncPullRequest workspaceId: string @@ -14294,6 +14442,94 @@ export type $OpenApiTs = { } } } + "/workspaces/{workspace_id}/workflows/sync/status": { + get: { + req: WorkflowsGetWorkspaceSyncStatusData + res: { + /** + * Successful Response + */ + 200: WorkspaceSyncStatus + /** + * Validation Error + */ + 422: HTTPValidationError + } + } + } + "/workspaces/{workspace_id}/workflows/sync/pending": { + get: { + req: WorkflowsListWorkspaceSyncPendingChangesData + res: { + /** + * Successful Response + */ + 200: WorkspaceSyncPendingChanges + /** + * Validation Error + */ + 422: HTTPValidationError + } + } + } + "/workspaces/{workspace_id}/workflows/sync/changesets": { + get: { + req: WorkflowsListWorkspaceSyncChangesetsData + res: { + /** + * Successful Response + */ + 200: Array + /** + * Validation Error + */ + 422: HTTPValidationError + } + } + post: { + req: WorkflowsCreateWorkspaceSyncChangesetData + res: { + /** + * Successful Response + */ + 201: ChangeSetRead + /** + * Validation Error + */ + 422: HTTPValidationError + } + } + } + "/workspaces/{workspace_id}/workflows/sync/changesets/{changeset_id}": { + get: { + req: WorkflowsGetWorkspaceSyncChangesetData + res: { + /** + * Successful Response + */ + 200: ChangeSetRead + /** + * Validation Error + */ + 422: HTTPValidationError + } + } + } + "/workspaces/{workspace_id}/workflows/sync/changesets/{changeset_id}/export": { + post: { + req: WorkflowsExportWorkspaceSyncChangesetData + res: { + /** + * Successful Response + */ + 200: WorkspaceSyncExportResult + /** + * Validation Error + */ + 422: HTTPValidationError + } + } + } "/workspaces/{workspace_id}/workflows/sync/pull": { post: { req: WorkflowsPullWorkflowsData diff --git a/frontend/src/components/organization/workspace-sync-staging.tsx b/frontend/src/components/organization/workspace-sync-staging.tsx new file mode 100644 index 0000000000..844285b2b5 --- /dev/null +++ b/frontend/src/components/organization/workspace-sync-staging.tsx @@ -0,0 +1,429 @@ +"use client" + +import { + GitBranchIcon, + GitPullRequestArrowIcon, + RefreshCwIcon, + UploadCloudIcon, +} from "lucide-react" +import * as React from "react" + +import type { ResourceRef, SyncOperation, SyncStateStatus } from "@/client" +import { CenteredSpinner } from "@/components/loading/spinner" +import { Badge } from "@/components/ui/badge" +import { Button } from "@/components/ui/button" +import { Checkbox } from "@/components/ui/checkbox" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table" +import { useToast } from "@/components/ui/use-toast" +import { + useWorkspaceSyncChangesetActions, + useWorkspaceSyncChangesets, + useWorkspaceSyncPendingChanges, + useWorkspaceSyncStatus, +} from "@/hooks/use-workspace-sync" +import { cn } from "@/lib/utils" +import { useOptionalWorkspaceId } from "@/providers/workspace-id" + +interface WorkspaceSyncStagingProps { + workspaceId?: string +} + +const statusLabels: Record = { + never_synced: "Never synced", + clean: "Clean", + local_dirty: "Local dirty", + remote_ahead: "Remote ahead", + diverged: "Diverged", + conflicted: "Conflicted", + error: "Error", +} + +const operationLabels: Record = { + create: "Create", + update: "Update", + delete: "Delete", + archive: "Archive", + disable: "Disable", +} + +function statusClass(status: SyncStateStatus | undefined) { + switch (status) { + case "clean": + return "border-emerald-200 bg-emerald-50 text-emerald-700" + case "local_dirty": + case "remote_ahead": + return "border-amber-200 bg-amber-50 text-amber-700" + case "diverged": + case "conflicted": + case "error": + return "border-rose-200 bg-rose-50 text-rose-700" + default: + return "border-muted-foreground/20 text-muted-foreground" + } +} + +function operationClass(operation: SyncOperation) { + switch (operation) { + case "create": + return "border-emerald-200 bg-emerald-50 text-emerald-700" + case "update": + return "border-blue-200 bg-blue-50 text-blue-700" + default: + return "border-muted-foreground/20 text-muted-foreground" + } +} + +function errorMessage(error: unknown) { + if (error instanceof Error) { + return error.message + } + return "Request failed" +} + +/** + * Renders local workspace Git changes and exports selected resources as a changeset. + */ +export function WorkspaceSyncStaging({ + workspaceId: workspaceIdProp, +}: WorkspaceSyncStagingProps) { + const contextWorkspaceId = useOptionalWorkspaceId() + const workspaceId = workspaceIdProp ?? contextWorkspaceId + const { toast } = useToast() + const statusQuery = useWorkspaceSyncStatus(workspaceId) + const pendingQuery = useWorkspaceSyncPendingChanges(workspaceId) + const changesetsQuery = useWorkspaceSyncChangesets(workspaceId, { limit: 5 }) + const { createChangeset, exportChangeset } = + useWorkspaceSyncChangesetActions(workspaceId) + + const changes = React.useMemo( + () => pendingQuery.data?.changes ?? [], + [pendingQuery.data?.changes] + ) + const selectionInitializedRef = React.useRef(false) + const [selectedSourceIds, setSelectedSourceIds] = React.useState>( + new Set() + ) + const [branch, setBranch] = React.useState("") + const [message, setMessage] = React.useState("Export workspace sync changes") + const [createPr, setCreatePr] = React.useState(true) + + React.useEffect(() => { + if (!pendingQuery.data) { + return + } + const sourceIds = new Set(changes.map((change) => change.source_id)) + setSelectedSourceIds((current) => { + if (sourceIds.size === 0) { + return sourceIds + } + if (!selectionInitializedRef.current) { + selectionInitializedRef.current = true + return sourceIds + } + return new Set([...current].filter((sourceId) => sourceIds.has(sourceId))) + }) + }, [changes, pendingQuery.data]) + + React.useEffect(() => { + if (branch || changes.length === 0) { + return + } + setBranch(`sync/${changes[0].source_id}`) + }, [branch, changes]) + + if (!workspaceId) { + return ( +
+ Workspace context unavailable. +
+ ) + } + + const selectedChanges = changes.filter((change) => + selectedSourceIds.has(change.source_id) + ) + const isBusy = createChangeset.isPending || exportChangeset.isPending + const canExport = + selectedChanges.length > 0 && branch.trim().length > 0 && !isBusy + + async function handleExport() { + if (!workspaceId || !canExport) { + return + } + + try { + const resources: ResourceRef[] = selectedChanges.map((change) => ({ + resource_type: change.resource_type, + source_id: change.source_id, + source_path: change.source_path, + local_id: change.local_id, + })) + const title = message.trim() || "Export workspace sync changes" + const changeset = await createChangeset.mutateAsync({ + title, + resources, + }) + const result = await exportChangeset.mutateAsync({ + changesetId: changeset.id, + requestBody: { + message: title, + branch: branch.trim(), + create_pr: createPr, + }, + }) + + toast({ + title: result.commit.pr_url ? "Pull request ready" : "Changes exported", + description: + result.commit.pr_url ?? result.commit.sha ?? result.commit.message, + }) + } catch (error) { + toast({ + title: "Export failed", + description: errorMessage(error), + variant: "destructive", + }) + } + } + + return ( +
+
+
+
+

Workspace sync

+

+ {statusQuery.data?.target_ref + ? `Tracking ${statusQuery.data.target_ref}` + : "No tracked ref"} +

+
+
+ + {statusQuery.data?.status + ? statusLabels[statusQuery.data.status] + : "Loading"} + + +
+
+ + {statusQuery.isLoading || pendingQuery.isLoading ? ( +
+ +
+ ) : statusQuery.error || pendingQuery.error ? ( +
+ {errorMessage(statusQuery.error ?? pendingQuery.error)} +
+ ) : ( + <> +
+
+ + {statusQuery.data?.pending_change_count ?? changes.length} + + Pending +
+
+ + {statusQuery.data?.base_commit_sha ?? "None"} + + Base commit +
+
+ + {statusQuery.data?.remote_commit_sha ?? "Unavailable"} + + Remote commit +
+
+ + {pendingQuery.data?.local_spec_hash} + + Local spec +
+
+ +
+ + + + + Resource + Operation + Path + + + + {changes.length === 0 ? ( + + + No pending changes + + + ) : ( + changes.map((change) => ( + + + { + setSelectedSourceIds((current) => { + const next = new Set(current) + if (checked === true) { + next.add(change.source_id) + } else { + next.delete(change.source_id) + } + return next + }) + }} + /> + + +
+ {change.title ?? change.source_id} +
+
+ {change.alias ?? change.source_id} +
+
+ + + {operationLabels[change.operation]} + + + + {change.source_path} + +
+ )) + )} +
+
+
+ +
+
+ + setMessage(event.target.value)} + /> +
+
+ + setBranch(event.target.value)} + /> +
+
+ + +
+
+ + )} +
+ +
+
+ +

Recent changesets

+
+
+ {(changesetsQuery.data ?? []).length === 0 ? ( +
+ No changesets +
+ ) : ( + (changesetsQuery.data ?? []).map((changeset) => ( +
+
+
+ {changeset.title} +
+
+ + {changeset.selected_paths.length} files + {changeset.status} +
+
+ +
+ )) + )} +
+
+
+ ) +} diff --git a/frontend/src/components/settings/settings-modal.tsx b/frontend/src/components/settings/settings-modal.tsx index cbf2a970e6..6e8b5969ad 100644 --- a/frontend/src/components/settings/settings-modal.tsx +++ b/frontend/src/components/settings/settings-modal.tsx @@ -175,15 +175,15 @@ function SettingsModalContent() { const showSyncNav = hasEntitlement("git_sync") return ( - + Settings Manage your account and workspace settings -
+
{/* Left nav panel */} -
+
Account @@ -261,7 +261,7 @@ function SettingsModalContent() {
{/* Right content panel */} -
+
{displayedSection === "profile" ? ( ) : workspaceId ? ( diff --git a/frontend/src/components/settings/workspace-sync-settings.tsx b/frontend/src/components/settings/workspace-sync-settings.tsx index 9750a790c3..9fa73e80d9 100644 --- a/frontend/src/components/settings/workspace-sync-settings.tsx +++ b/frontend/src/components/settings/workspace-sync-settings.tsx @@ -6,6 +6,7 @@ import { useState } from "react" import { useForm } from "react-hook-form" import { z } from "zod" import type { WorkspaceRead } from "@/client" +import { WorkspaceSyncStaging } from "@/components/organization/workspace-sync-staging" import { Button } from "@/components/ui/button" import { Form, @@ -93,34 +94,41 @@ export function WorkspaceSyncSettings({ {persistedGitUrl && ( -
-
-
-
Workflow synchronization
-

- Pull workflow definitions from your Git repository into this - workspace + <> +

+
+
+
+ Workflow synchronization +
+

+ Pull workflow definitions from your Git repository into this + workspace +

+
+ +
+ +
+

• Select a commit SHA to pull specific workflow versions

+

+ • All changes are atomic - either all workflows import or none + do

-
-
-

• Select a commit SHA to pull specific workflow versions

-

- • All changes are atomic - either all workflows import or none do -

-
-
+ + )} ({ + queryKey: ["workspace-sync-status", workspaceId], + queryFn: async () => { + if (!workspaceId) { + throw new Error("Workspace ID is required") + } + return await workflowsGetWorkspaceSyncStatus({ workspaceId }) + }, + enabled: !!workspaceId && options?.enabled !== false, + }) +} + +/** + * Fetches local workspace changes that can be exported to Git. + */ +export function useWorkspaceSyncPendingChanges( + workspaceId: string | undefined, + options?: { enabled?: boolean } +) { + return useQuery({ + queryKey: ["workspace-sync-pending", workspaceId], + queryFn: async () => { + if (!workspaceId) { + throw new Error("Workspace ID is required") + } + return await workflowsListWorkspaceSyncPendingChanges({ workspaceId }) + }, + enabled: !!workspaceId && options?.enabled !== false, + }) +} + +/** + * Fetches recent workspace sync changesets for review and re-selection. + */ +export function useWorkspaceSyncChangesets( + workspaceId: string | undefined, + options?: { enabled?: boolean; limit?: number } +) { + return useQuery({ + queryKey: ["workspace-sync-changesets", workspaceId, options?.limit ?? 20], + queryFn: async () => { + if (!workspaceId) { + throw new Error("Workspace ID is required") + } + return await workflowsListWorkspaceSyncChangesets({ + workspaceId, + limit: options?.limit ?? 20, + }) + }, + enabled: !!workspaceId && options?.enabled !== false, + }) +} + +/** + * Provides mutations for creating and exporting workspace sync changesets. + */ +export function useWorkspaceSyncChangesetActions( + workspaceId: string | undefined +) { + const queryClient = useQueryClient() + const invalidateSyncQueries = () => { + queryClient.invalidateQueries({ + queryKey: ["workspace-sync-status", workspaceId], + }) + queryClient.invalidateQueries({ + queryKey: ["workspace-sync-pending", workspaceId], + }) + queryClient.invalidateQueries({ + queryKey: ["workspace-sync-changesets", workspaceId], + }) + } + + const createChangeset = useMutation({ + mutationFn: async (requestBody) => { + if (!workspaceId) { + throw new Error("Workspace ID is required") + } + return await workflowsCreateWorkspaceSyncChangeset({ + workspaceId, + requestBody, + }) + }, + onSuccess: invalidateSyncQueries, + }) + + const exportChangeset = useMutation< + WorkspaceSyncExportResult, + Error, + { changesetId: string; requestBody: ChangeSetExport } + >({ + mutationFn: async ({ changesetId, requestBody }) => { + if (!workspaceId) { + throw new Error("Workspace ID is required") + } + return await workflowsExportWorkspaceSyncChangeset({ + workspaceId, + changesetId, + requestBody, + }) + }, + onSuccess: invalidateSyncQueries, + }) + + return { + createChangeset, + exportChangeset, + } +} diff --git a/tests/unit/api/test_api_workflow_store_publish.py b/tests/unit/api/test_api_workflow_store_publish.py index 6ee53d9d89..2dce19cf95 100644 --- a/tests/unit/api/test_api_workflow_store_publish.py +++ b/tests/unit/api/test_api_workflow_store_publish.py @@ -10,8 +10,10 @@ from tracecat.auth.types import Role from tracecat.exceptions import TracecatValidationError from tracecat.registry.repositories.schemas import GitBranchInfo +from tracecat.sync import CommitInfo, PushStatus from tracecat.vcs.github.app import GitHubAppError from tracecat.workflow.store.schemas import WorkflowDslPublishResult +from tracecat.workspace_sync.schemas import ChangeSetRead, WorkspaceSyncExportResult def _sample_dsl_content() -> dict[str, object]: @@ -186,7 +188,9 @@ async def test_list_workflow_branches_success( """Test GET /workflows/sync/branches returns branch list.""" with ( patch("tracecat.workflow.store.router.WorkspaceService") as mock_workspace_cls, - patch("tracecat.workflow.store.router.WorkflowSyncService") as mock_sync_cls, + patch( + "tracecat.workflow.store.router.WorkspaceGitSyncService" + ) as mock_sync_cls, ): mock_workspace_svc = AsyncMock() mock_workspace = Mock() @@ -246,7 +250,9 @@ async def test_list_workflow_branches_github_error_returns_400( """Test GET /workflows/sync/branches maps GitHub errors to 400.""" with ( patch("tracecat.workflow.store.router.WorkspaceService") as mock_workspace_cls, - patch("tracecat.workflow.store.router.WorkflowSyncService") as mock_sync_cls, + patch( + "tracecat.workflow.store.router.WorkspaceGitSyncService" + ) as mock_sync_cls, ): mock_workspace_svc = AsyncMock() mock_workspace = Mock() @@ -269,3 +275,111 @@ async def test_list_workflow_branches_github_error_returns_400( assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Unable to access repository" in response.json()["detail"] + + +@pytest.mark.anyio +async def test_workspace_sync_changeset_routes( + client: TestClient, + test_admin_role: Role, +) -> None: + changeset_id = uuid.uuid4() + changeset = ChangeSetRead( + id=changeset_id, + title="Export workflow", + selected_resources=[ + { + "resource_type": "workflow", + "source_id": "detect-okta-risk", + "source_path": "workflows/detect-okta-risk/definition.yml", + } + ], + selected_paths=[ + "tracecat.json", + "workflows/detect-okta-risk/definition.yml", + ], + validation_status="valid", + validation_result={}, + status="validated", + ) + + with patch("tracecat.workflow.store.router.WorkspaceGitSyncService") as sync_cls: + sync_svc = AsyncMock() + sync_svc.create_changeset.return_value = changeset + sync_svc.get_changeset.return_value = changeset + sync_svc.list_changesets.return_value = [changeset] + sync_svc.export_changeset.return_value = WorkspaceSyncExportResult( + changeset_id=changeset_id, + commit=CommitInfo( + status=PushStatus.COMMITTED, + sha="c" * 40, + ref="sync/detect-okta-risk", + base_ref="main", + pr_url="https://github.com/test-org/test-repo/pull/1", + pr_number=1, + pr_reused=False, + message="Committed workspace sync changes.", + ), + ) + sync_cls.return_value = sync_svc + + create_response = client.post( + "/workflows/sync/changesets", + params={"workspace_id": str(test_admin_role.workspace_id)}, + json={ + "title": "Export workflow", + "resources": [ + { + "resource_type": "workflow", + "source_id": "detect-okta-risk", + "source_path": "workflows/detect-okta-risk/definition.yml", + } + ], + }, + ) + list_response = client.get( + "/workflows/sync/changesets", + params={"workspace_id": str(test_admin_role.workspace_id)}, + ) + get_response = client.get( + f"/workflows/sync/changesets/{changeset_id}", + params={"workspace_id": str(test_admin_role.workspace_id)}, + ) + export_response = client.post( + f"/workflows/sync/changesets/{changeset_id}/export", + params={"workspace_id": str(test_admin_role.workspace_id)}, + json={ + "message": "Export workflow", + "branch": "sync/detect-okta-risk", + "create_pr": True, + }, + ) + + assert create_response.status_code == status.HTTP_201_CREATED + assert list_response.status_code == status.HTTP_200_OK + assert get_response.status_code == status.HTTP_200_OK + assert export_response.status_code == status.HTTP_200_OK + assert create_response.json()["selected_resources"][0]["source_id"] == ( + "detect-okta-risk" + ) + assert list_response.json()[0]["id"] == str(changeset_id) + assert get_response.json()["status"] == "validated" + assert export_response.json()["commit"]["pr_number"] == 1 + + +@pytest.mark.anyio +async def test_workspace_sync_changeset_export_invalid_branch_returns_400( + client: TestClient, + test_admin_role: Role, +) -> None: + response = client.post( + f"/workflows/sync/changesets/{uuid.uuid4()}/export", + params={"workspace_id": str(test_admin_role.workspace_id)}, + json={ + "message": "Export workflow", + "branch": "refs/heads/sync/detect-okta-risk", + "create_pr": True, + }, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "short branch name" in response.json()["detail"] diff --git a/tests/unit/test_workflow_store_service.py b/tests/unit/test_workflow_store_service.py index 72673f9f21..1310093902 100644 --- a/tests/unit/test_workflow_store_service.py +++ b/tests/unit/test_workflow_store_service.py @@ -9,13 +9,11 @@ from tracecat.auth.types import Role from tracecat.authz.scopes import SERVICE_PRINCIPAL_SCOPES -from tracecat.cases.enums import CaseEventType from tracecat.db.models import Workflow from tracecat.dsl.common import DSLEntrypoint, DSLInput from tracecat.dsl.schemas import ActionStatement from tracecat.identifiers.workflow import WorkflowUUID -from tracecat.sync import PushStatus -from tracecat.workflow.store.schemas import WorkflowDslPublish +from tracecat.workflow.store.schemas import WorkflowDslPublish, WorkflowDslPublishResult from tracecat.workflow.store.service import WorkflowStoreService @@ -69,7 +67,7 @@ def _workflow_fixture( @pytest.mark.anyio -async def test_publish_workflow_omits_inert_case_trigger( +async def test_publish_workflow_uses_workspace_sync_exporter( workflow_store_service: WorkflowStoreService, sample_dsl: DSLInput, ) -> None: @@ -83,44 +81,40 @@ async def test_publish_workflow_omits_inert_case_trigger( ), ) - with ( - patch("tracecat.workflow.store.service.WorkspaceService") as workspace_cls, - patch("tracecat.workflow.store.service.WorkflowSyncService") as sync_cls, - ): - workspace_service = AsyncMock() - workspace_service.get_workspace.return_value = SimpleNamespace( - settings={"git_repo_url": "git+ssh://git@github.com/test-org/test-repo.git"} - ) - workspace_cls.return_value = workspace_service - + with patch("tracecat.workflow.store.service.WorkspaceGitSyncService") as sync_cls: sync_service = AsyncMock() - sync_service.push.return_value = SimpleNamespace( - status=PushStatus.NO_OP, - sha="abc123", - ref="feature/test", - base_ref="main", - pr_url=None, - pr_number=None, - pr_reused=False, - message="No changes", + sync_service.export_workflow_publish_result.return_value = ( + WorkflowDslPublishResult( + status="no_op", + commit_sha="abc123", + branch="feature/test", + base_branch="main", + pr_url=None, + pr_number=None, + pr_reused=False, + message="No changes", + ) ) sync_cls.return_value = sync_service - await workflow_store_service.publish_workflow_dsl( + result = await workflow_store_service.publish_workflow_dsl( workflow_id=workflow_id, dsl=sample_dsl, params=WorkflowDslPublish(branch="feature/test", create_pr=False), workflow=cast(Workflow, workflow), ) - push_obj = sync_service.push.call_args.kwargs["objects"][0] - assert push_obj.data.case_trigger is None - # include_headers must be exported so it survives a store round-trip - assert push_obj.data.webhook.include_headers is True + assert result.status == "no_op" + sync_service.export_workflow_publish_result.assert_awaited_once() + call = sync_service.export_workflow_publish_result.call_args.kwargs + assert call["workflow"] is workflow + assert call["dsl"] is sample_dsl + assert call["options"].branch == "feature/test" + assert call["options"].create_pr is False @pytest.mark.anyio -async def test_publish_workflow_includes_configured_case_trigger( +async def test_publish_workflow_legacy_mode_uses_temp_branch_and_pr( workflow_store_service: WorkflowStoreService, sample_dsl: DSLInput, ) -> None: @@ -128,44 +122,35 @@ async def test_publish_workflow_includes_configured_case_trigger( workflow = _workflow_fixture( workflow_id, case_trigger=SimpleNamespace( - status="online", - event_types=[CaseEventType.CASE_CREATED.value], - tag_filters=["phishing"], + status="offline", + event_types=[], + tag_filters=[], ), ) - with ( - patch("tracecat.workflow.store.service.WorkspaceService") as workspace_cls, - patch("tracecat.workflow.store.service.WorkflowSyncService") as sync_cls, - ): - workspace_service = AsyncMock() - workspace_service.get_workspace.return_value = SimpleNamespace( - settings={"git_repo_url": "git+ssh://git@github.com/test-org/test-repo.git"} - ) - workspace_cls.return_value = workspace_service - + with patch("tracecat.workflow.store.service.WorkspaceGitSyncService") as sync_cls: sync_service = AsyncMock() - sync_service.push.return_value = SimpleNamespace( - status=PushStatus.NO_OP, - sha="abc123", - ref="feature/test", - base_ref="main", - pr_url=None, - pr_number=None, - pr_reused=False, - message="No changes", + sync_service.export_workflow_publish_result.return_value = ( + WorkflowDslPublishResult( + status="no_op", + commit_sha="abc123", + branch="tracecat-sync-20260605-000000", + base_branch="main", + pr_url=None, + pr_number=None, + pr_reused=False, + message="No changes", + ) ) sync_cls.return_value = sync_service await workflow_store_service.publish_workflow_dsl( workflow_id=workflow_id, dsl=sample_dsl, - params=WorkflowDslPublish(branch="feature/test", create_pr=False), + params=WorkflowDslPublish(branch=None, create_pr=False), workflow=cast(Workflow, workflow), ) - push_obj = sync_service.push.call_args.kwargs["objects"][0] - assert push_obj.data.case_trigger is not None - assert push_obj.data.case_trigger.status == "online" - assert push_obj.data.case_trigger.event_types == [CaseEventType.CASE_CREATED] - assert push_obj.data.case_trigger.tag_filters == ["phishing"] + call = sync_service.export_workflow_publish_result.call_args.kwargs + assert call["options"].branch.startswith("tracecat-sync-") + assert call["options"].create_pr is True diff --git a/tests/unit/test_workflow_sync_service.py b/tests/unit/test_workflow_sync_service.py deleted file mode 100644 index 3482a6f11c..0000000000 --- a/tests/unit/test_workflow_sync_service.py +++ /dev/null @@ -1,1275 +0,0 @@ -"""Tests for WorkflowSyncService functionality.""" - -import base64 -import uuid -from unittest.mock import AsyncMock, Mock, patch - -import pytest -import yaml -from github.GithubException import GithubException - -from tracecat.auth.types import Role -from tracecat.authz.scopes import SERVICE_PRINCIPAL_SCOPES -from tracecat.dsl.common import DSLEntrypoint, DSLInput -from tracecat.dsl.schemas import ActionStatement -from tracecat.git.types import GitUrl -from tracecat.sync import Author, PushObject, PushOptions, PushStatus -from tracecat.workflow.store.import_service import WorkflowImportService -from tracecat.workflow.store.schemas import RemoteWorkflowDefinition -from tracecat.workflow.store.sync import WorkflowSyncService - - -@pytest.fixture -def workspace_id(): - """Test workspace ID.""" - return uuid.UUID("550e8400-e29b-41d4-a716-446655440000") - - -@pytest.fixture -def git_url(): - """Test Git URL.""" - return GitUrl(host="github.com", org="test-org", repo="test-repo", ref="main") - - -@pytest.fixture -def sample_workflow(): - """Sample workflow DSL.""" - return DSLInput( - title="Test Workflow", - description="A test workflow", - entrypoint=DSLEntrypoint(ref="start", expects={}), - actions=[ - ActionStatement( - ref="start", - action="core.transform.passthrough", - args={"value": "test"}, - ) - ], - ) - - -@pytest.fixture -def sample_remote_workflow(sample_workflow): - """Sample RemoteWorkflowDefinition.""" - return RemoteWorkflowDefinition( - id="wf_123abc", - alias="test-workflow", - definition=sample_workflow, - ) - - -@pytest.fixture -def sample_remote_workflow_with_folder(sample_workflow): - """Sample RemoteWorkflowDefinition with folder_path.""" - return RemoteWorkflowDefinition( - id="wf_folder123", - alias="test-workflow-with-folder", - folder_path="/security/detections/", - definition=sample_workflow, - ) - - -@pytest.fixture -def organization_id(): - """Test organization ID.""" - return uuid.UUID("550e8400-e29b-41d4-a716-446655440001") - - -@pytest.fixture -def workflow_sync_service(workspace_id, organization_id): - """WorkflowSyncService instance for testing.""" - # Use a mock session for unit tests - mock_session = AsyncMock() - role = Role( - type="service", - service_id="tracecat-api", - workspace_id=workspace_id, - organization_id=organization_id, - scopes=SERVICE_PRINCIPAL_SCOPES["tracecat-api"], - ) - return WorkflowSyncService(session=mock_session, role=role) - - -@pytest.fixture -def workflow_import_service(workspace_id, organization_id): - """WorkflowImportService instance for testing.""" - # Use a mock session for unit tests - mock_session = AsyncMock() - role = Role( - type="service", - service_id="tracecat-api", - workspace_id=workspace_id, - organization_id=organization_id, - scopes=SERVICE_PRINCIPAL_SCOPES["tracecat-api"], - ) - return WorkflowImportService(session=mock_session, role=role) - - -class TestWorkflowSyncService: - """Tests for WorkflowSyncService.""" - - @pytest.mark.anyio - async def test_pull_requires_commit_sha(self, workflow_sync_service, git_url): - """Test that pull returns error result when commit_sha is missing.""" - result = await workflow_sync_service.pull(url=git_url) - - assert result.success is False - assert result.commit_sha == "" - assert result.workflows_found == 0 - assert result.workflows_imported == 0 - assert result.message == "commit_sha is required" - assert len(result.diagnostics) == 1 - assert result.diagnostics[0].error_type == "validation" - assert "commit_sha is required" in result.diagnostics[0].message - - @pytest.mark.anyio - async def test_push_workflows_success( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test successful workflow push to Git repository.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", author=author, create_pr=False - ) - - mock_repo = Mock() - mock_branch = Mock() - mock_branch.commit.sha = "abc123def" - mock_repo.default_branch = "main" - mock_repo.get_branch.return_value = mock_branch - mock_repo.create_git_ref = Mock() - mock_repo.create_file = Mock() - # Mock get_contents to raise 404 (file doesn't exist) - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - # Mock asyncio.to_thread to return the direct result (not coroutine) - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.COMMITTED - assert result.sha == "abc123def" - assert result.ref.startswith("tracecat-sync-") - - # Verify GitHub API calls were made - mock_repo.create_git_ref.assert_called_once() - mock_repo.create_file.assert_called_once() - - @pytest.mark.anyio - async def test_push_objects_with_stable_path( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test push with explicit stable path using PushObject.""" - stable_path = "workflows/wf_123abc.yml" - push_item = PushObject(data=sample_remote_workflow, path=stable_path) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", author=author, create_pr=False - ) - - mock_repo = Mock() - mock_branch = Mock() - mock_branch.commit.sha = "abc123def" - mock_repo.default_branch = "main" - mock_repo.get_branch.return_value = mock_branch - mock_repo.create_git_ref = Mock() - mock_repo.create_file = Mock() - # Mock get_contents to raise 404 (file doesn't exist) - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - # Mock asyncio.to_thread to return the direct result (not coroutine) - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_item], url=git_url, options=options - ) - - assert result.status == PushStatus.COMMITTED - assert result.sha == "abc123def" - assert result.ref.startswith("tracecat-sync-") - - # Verify exact path was used - mock_repo.create_file.assert_called_once() - call_args = mock_repo.create_file.call_args - assert call_args.kwargs["path"] == stable_path - - @pytest.mark.anyio - async def test_push_workflows_with_pr( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test workflow push with pull request creation.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions(message="Update workflows", author=author, create_pr=True) - - mock_repo = Mock() - mock_branch = Mock() - mock_branch.commit.sha = "abc123def" - mock_repo.default_branch = "main" - mock_repo.get_branch.return_value = mock_branch - mock_repo.create_git_ref = Mock() - mock_repo.create_file = Mock() - # Mock get_contents to raise 404 (file doesn't exist) - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - - mock_pr = Mock() - mock_pr.html_url = "https://github.com/test-org/test-repo/pull/123" - mock_pr.number = 123 - mock_repo.create_pull.return_value = mock_pr - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - patch( - "tracecat.workflow.store.sync.WorkspaceService" - ) as mock_ws_service_class, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - # Mock WorkspaceService - mock_ws_service = AsyncMock() - mock_workspace = Mock() - mock_workspace.name = "Test Workspace" - mock_ws_service.get_workspace.return_value = mock_workspace - mock_ws_service_class.return_value = mock_ws_service - - # Mock asyncio.to_thread to return the direct result (not coroutine) - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.COMMITTED - assert result.sha == "abc123def" - assert result.ref.startswith("tracecat-sync-") - - # Verify PR was created - mock_repo.create_pull.assert_called_once() - - @pytest.mark.anyio - async def test_push_workflows_empty_objects(self, workflow_sync_service, git_url): - """Test push fails with empty objects list.""" - author = Author(name="Test User", email="test@example.com") - options = PushOptions(message="Update workflows", author=author) - - with pytest.raises( - ValueError, match="We only support pushing one workflow object at a time" - ): - await workflow_sync_service.push(objects=[], url=git_url, options=options) - - @pytest.mark.anyio - async def test_push_objects_empty_objects(self, workflow_sync_service, git_url): - """Test push fails with empty objects list.""" - author = Author(name="Test User", email="test@example.com") - options = PushOptions(message="Update workflows", author=author) - - with pytest.raises( - ValueError, match="We only support pushing one workflow object at a time" - ): - await workflow_sync_service.push(objects=[], url=git_url, options=options) - - @pytest.mark.anyio - async def test_push_workflows_github_failure( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test push handles GitHub API failures.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions(message="Update workflows", author=author) - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo.side_effect = Exception( - "GitHub API error" - ) - mock_gh_service_class.return_value = mock_gh_service - - with pytest.raises(Exception, match="GitHub API error"): - await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - @pytest.mark.anyio - async def test_filename_generation_backward_compatibility( - self, workflow_sync_service - ): - """Test workflow filename generation from title (backward compatibility).""" - # Test with normal title - workflow1 = DSLInput( - title="My Test Workflow", - description="Test", - entrypoint=DSLEntrypoint(ref="start", expects={}), - actions=[ActionStatement(ref="start", action="core.noop", args={})], - ) - - # Create remote workflow definition - remote_workflow = RemoteWorkflowDefinition( - id="wf_test123", - alias="my-test-workflow", - definition=workflow1, - ) - - # Use title-based path for backward compatibility test - push_obj = PushObject( - data=remote_workflow, path="workflows/my-test-workflow.yaml" - ) - - author = Author(name="Test User", email="test@example.com") - options = PushOptions(message="Test", author=author) - git_url = GitUrl(host="github.com", org="test", repo="test") - - mock_repo = Mock() - mock_branch = Mock() - mock_branch.commit.sha = "abc123def" - mock_repo.default_branch = "main" - mock_repo.get_branch.return_value = mock_branch - mock_repo.create_git_ref = Mock() - mock_repo.create_file = Mock() - # Mock get_contents to raise 404 (file doesn't exist) - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - # Mock asyncio.to_thread to return the direct result (not coroutine) - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - # Verify the file was created with sanitized filename (backward compatibility) - mock_repo.create_file.assert_called_once() - call_args = mock_repo.create_file.call_args - file_path = call_args.kwargs["path"] - assert file_path == "workflows/my-test-workflow.yaml" - - @pytest.mark.anyio - async def test_push_target_branch_commits_to_existing_branch( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test branch-target mode commits directly to an existing branch.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", - author=author, - create_pr=False, - branch="feature/shared-workflow", - ) - - mock_repo = Mock() - base_branch = Mock() - base_branch.commit.sha = "base123" - target_branch = Mock() - target_branch.commit.sha = "target123" - mock_repo.default_branch = "main" - mock_repo.get_branch.side_effect = lambda name: ( - base_branch if name == "main" else target_branch - ) - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - mock_repo.create_file = Mock() - mock_repo.create_git_ref = Mock() - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.COMMITTED - assert result.sha == "target123" - assert result.ref == "feature/shared-workflow" - assert result.base_ref == "main" - assert result.pr_url is None - assert result.pr_number is None - assert result.pr_reused is False - mock_repo.create_git_ref.assert_not_called() - mock_repo.create_file.assert_called_once() - - @pytest.mark.anyio - async def test_push_target_branch_creates_missing_branch_from_base( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test branch-target mode creates target branch when it is missing.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", - author=author, - create_pr=False, - branch="feature/new-workflow", - ) - - mock_repo = Mock() - base_branch = Mock() - base_branch.commit.sha = "base123" - target_branch = Mock() - target_branch.commit.sha = "target123" - mock_repo.default_branch = "main" - - target_branch_lookup_count = 0 - - def get_branch(name: str): - nonlocal target_branch_lookup_count - if name == "main": - return base_branch - if name == "feature/new-workflow": - target_branch_lookup_count += 1 - if target_branch_lookup_count == 1: - raise GithubException(404, {"message": "Not Found"}, {}) - return target_branch - raise AssertionError(f"Unexpected branch lookup: {name}") - - mock_repo.get_branch.side_effect = get_branch - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - mock_repo.create_file = Mock() - mock_repo.create_git_ref = Mock() - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.COMMITTED - assert result.ref == "feature/new-workflow" - assert result.base_ref == "main" - mock_repo.create_git_ref.assert_called_once_with( - ref="refs/heads/feature/new-workflow", - sha="base123", - ) - - @pytest.mark.anyio - async def test_push_target_branch_defaults_base_to_url_ref( - self, workflow_sync_service, sample_remote_workflow - ): - """Test branch-target mode uses URL ref as base when pr_base_branch is unset.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", - author=author, - create_pr=False, - branch="feature/new-workflow", - ) - git_url = GitUrl( - host="github.com", org="test-org", repo="test-repo", ref="release" - ) - - mock_repo = Mock() - release_branch = Mock() - release_branch.commit.sha = "release123" - target_branch = Mock() - target_branch.commit.sha = "target123" - mock_repo.default_branch = "main" - - def get_branch(name: str): - if name == "release": - return release_branch - if name == "feature/new-workflow": - return target_branch - raise AssertionError(f"Unexpected branch lookup: {name}") - - mock_repo.get_branch.side_effect = get_branch - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - mock_repo.create_file = Mock() - mock_repo.create_git_ref = Mock() - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.COMMITTED - assert result.base_ref == "release" - mock_repo.create_git_ref.assert_not_called() - - @pytest.mark.anyio - async def test_push_target_branch_noop_returns_no_op( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test branch-target mode returns no_op on identical file contents.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", - author=author, - create_pr=False, - branch="feature/shared-workflow", - ) - - expected_yaml = yaml.dump( - sample_remote_workflow.model_dump( - mode="json", exclude_none=True, exclude_unset=True - ), - sort_keys=False, - ) - - mock_repo = Mock() - base_branch = Mock() - base_branch.commit.sha = "base123" - target_branch = Mock() - target_branch.commit.sha = "target123" - mock_repo.default_branch = "main" - mock_repo.get_branch.side_effect = lambda name: ( - base_branch if name == "main" else target_branch - ) - - mock_contents = Mock() - mock_contents.path = "workflows/test-workflow.yml" - mock_contents.sha = "file123" - mock_contents.content = base64.b64encode(expected_yaml.encode("utf-8")).decode( - "utf-8" - ) - mock_repo.get_contents.return_value = mock_contents - mock_repo.update_file = Mock() - mock_repo.create_file = Mock() - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.NO_OP - assert result.sha is None - assert result.ref == "feature/shared-workflow" - assert result.base_ref == "main" - mock_repo.update_file.assert_not_called() - mock_repo.create_file.assert_not_called() - - @pytest.mark.anyio - async def test_push_empty_branch_still_uses_target_branch_mode( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test empty-string branch does not fall back to legacy mode.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", - author=author, - create_pr=False, - branch="", - ) - - mock_repo = Mock() - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - target_mode_result = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - patch.object( - workflow_sync_service, - "_push_to_target_branch", - new=AsyncMock(return_value=target_mode_result), - ) as mock_target_mode, - patch.object( - workflow_sync_service, - "_push_legacy", - new=AsyncMock(return_value=Mock()), - ) as mock_legacy_mode, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result is target_mode_result - mock_target_mode.assert_awaited_once() - mock_legacy_mode.assert_not_called() - - @pytest.mark.anyio - async def test_push_target_branch_create_pr_failure_returns_committed_result( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test successful commits are returned even when PR creation fails.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", - author=author, - create_pr=True, - branch="feature/shared-workflow", - ) - - mock_repo = Mock() - base_branch = Mock() - base_branch.commit.sha = "base123" - target_branch = Mock() - target_branch.commit.sha = "target123" - mock_repo.default_branch = "main" - mock_repo.get_branch.side_effect = lambda name: ( - base_branch if name == "main" else target_branch - ) - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - mock_repo.create_file = Mock() - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - patch.object( - workflow_sync_service, - "_upsert_pull_request", - new=AsyncMock( - side_effect=GithubException( - 422, {"message": "Validation Failed"}, {} - ) - ), - ) as mock_upsert_pr, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.COMMITTED - assert result.sha == "target123" - assert result.pr_url is None - assert result.pr_number is None - assert result.pr_reused is False - mock_upsert_pr.assert_awaited_once() - mock_repo.create_file.assert_called_once() - - @pytest.mark.anyio - async def test_push_target_branch_noop_pr_failure_returns_no_op_result( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test no-op publish still succeeds when PR creation fails.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", - author=author, - create_pr=True, - branch="feature/shared-workflow", - ) - - expected_yaml = yaml.dump( - sample_remote_workflow.model_dump( - mode="json", exclude_none=True, exclude_unset=True - ), - sort_keys=False, - ) - - mock_repo = Mock() - base_branch = Mock() - base_branch.commit.sha = "base123" - target_branch = Mock() - target_branch.commit.sha = "target123" - mock_repo.default_branch = "main" - mock_repo.get_branch.side_effect = lambda name: ( - base_branch if name == "main" else target_branch - ) - - mock_contents = Mock() - mock_contents.path = "workflows/test-workflow.yml" - mock_contents.sha = "file123" - mock_contents.content = base64.b64encode(expected_yaml.encode("utf-8")).decode( - "utf-8" - ) - mock_repo.get_contents.return_value = mock_contents - mock_repo.update_file = Mock() - mock_repo.create_file = Mock() - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - patch.object( - workflow_sync_service, - "_upsert_pull_request", - new=AsyncMock( - side_effect=GithubException( - 422, {"message": "Validation Failed"}, {} - ) - ), - ) as mock_upsert_pr, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.NO_OP - assert result.sha is None - assert result.pr_url is None - assert result.pr_number is None - assert result.pr_reused is False - mock_upsert_pr.assert_awaited_once() - mock_repo.update_file.assert_not_called() - mock_repo.create_file.assert_not_called() - - @pytest.mark.anyio - async def test_push_target_branch_create_pr_creates_new_pull_request( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test branch-target mode creates a PR when requested and none exists.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", - author=author, - create_pr=True, - branch="feature/shared-workflow", - ) - - mock_repo = Mock() - base_branch = Mock() - base_branch.commit.sha = "base123" - target_branch = Mock() - target_branch.commit.sha = "target123" - mock_repo.default_branch = "main" - mock_repo.get_branch.side_effect = lambda name: ( - base_branch if name == "main" else target_branch - ) - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - mock_repo.get_pulls.return_value = [] - mock_repo.create_file = Mock() - - mock_pr = Mock() - mock_pr.number = 456 - mock_pr.html_url = "https://github.com/test-org/test-repo/pull/456" - mock_repo.create_pull.return_value = mock_pr - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - patch( - "tracecat.workflow.store.sync.WorkspaceService" - ) as mock_ws_service_class, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - mock_ws_service = AsyncMock() - mock_workspace = Mock() - mock_workspace.name = "Test Workspace" - mock_ws_service.get_workspace.return_value = mock_workspace - mock_ws_service_class.return_value = mock_ws_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.COMMITTED - assert result.pr_reused is False - assert result.pr_number == 456 - assert result.pr_url == "https://github.com/test-org/test-repo/pull/456" - mock_repo.create_pull.assert_called_once() - - @pytest.mark.anyio - async def test_push_target_branch_create_pr_reuses_existing_pull_request( - self, workflow_sync_service, git_url, sample_remote_workflow - ): - """Test branch-target mode reuses existing open PR for same head/base.""" - push_obj = PushObject( - data=sample_remote_workflow, path="workflows/test-workflow.yml" - ) - author = Author(name="Test User", email="test@example.com") - options = PushOptions( - message="Update workflows", - author=author, - create_pr=True, - branch="feature/shared-workflow", - ) - - mock_repo = Mock() - base_branch = Mock() - base_branch.commit.sha = "base123" - target_branch = Mock() - target_branch.commit.sha = "target123" - mock_repo.default_branch = "main" - mock_repo.get_branch.side_effect = lambda name: ( - base_branch if name == "main" else target_branch - ) - mock_repo.get_contents.side_effect = GithubException( - 404, {"message": "Not Found"}, {} - ) - mock_repo.create_file = Mock() - - existing_pr = Mock() - existing_pr.number = 789 - existing_pr.html_url = "https://github.com/test-org/test-repo/pull/789" - mock_repo.get_pulls.return_value = [existing_pr] - mock_repo.create_pull = Mock() - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - result = await workflow_sync_service.push( - objects=[push_obj], url=git_url, options=options - ) - - assert result.status == PushStatus.COMMITTED - assert result.pr_reused is True - assert result.pr_number == 789 - assert result.pr_url == "https://github.com/test-org/test-repo/pull/789" - mock_repo.create_pull.assert_not_called() - - @pytest.mark.anyio - async def test_list_branches_success(self, workflow_sync_service, git_url): - """Test branch listing from GitHub repository.""" - mock_repo = Mock() - mock_repo.default_branch = "main" - - main_branch = Mock() - main_branch.name = "main" - feature_branch = Mock() - feature_branch.name = "feature/workflow-sync" - mock_repo.get_branches.return_value = [main_branch, feature_branch] - - mock_github_client = Mock() - mock_github_client.get_repo.return_value = mock_repo - mock_github_client.close = Mock() - - with ( - patch( - "tracecat.workflow.store.sync.GitHubAppService" - ) as mock_gh_service_class, - patch("asyncio.to_thread") as mock_to_thread, - ): - mock_gh_service = AsyncMock() - mock_gh_service.get_github_client_for_repo = AsyncMock( - return_value=mock_github_client - ) - mock_gh_service_class.return_value = mock_gh_service - - async def mock_to_thread_impl(func, *args, **kwargs): - return func(*args, **kwargs) - - mock_to_thread.side_effect = mock_to_thread_impl - - branches = await workflow_sync_service.list_branches(url=git_url, limit=10) - - assert len(branches) == 2 - assert branches[0].name == "main" - assert branches[0].is_default is True - assert branches[1].name == "feature/workflow-sync" - assert branches[1].is_default is False - - -class TestWorkflowImportServiceFolders: - """Tests for WorkflowImportService folder functionality.""" - - @pytest.mark.anyio - async def test_ensure_folder_exists_creates_nested_folders( - self, workflow_import_service - ): - """Test that _ensure_folder_exists creates nested folder structure.""" - # Mock folder service - mock_folder_service = AsyncMock() - workflow_import_service.folder_service = mock_folder_service - - # Mock that no folders exist initially - mock_folder_service.get_folder_by_path.side_effect = [ - None, # /security/ doesn't exist - None, # /security/detections/ doesn't exist - Mock(id=uuid.uuid4()), # final folder exists after creation - ] - - # Mock folder creation - mock_security_folder = Mock(id=uuid.uuid4()) - mock_detections_folder = Mock(id=uuid.uuid4()) - mock_folder_service.create_folder.side_effect = [ - mock_security_folder, - mock_detections_folder, - ] - - await workflow_import_service._ensure_folder_exists("/security/detections/") - - # Verify folders were created in correct order - assert mock_folder_service.create_folder.call_count == 2 - - # First call creates 'security' folder at root - first_call = mock_folder_service.create_folder.call_args_list[0] - assert first_call.kwargs["name"] == "security" - assert first_call.kwargs["parent_path"] == "/" - - # Second call creates 'detections' folder under /security/ - second_call = mock_folder_service.create_folder.call_args_list[1] - assert second_call.kwargs["name"] == "detections" - assert second_call.kwargs["parent_path"] == "/security/" - - # Final get_folder_by_path call to return created folder - mock_folder_service.get_folder_by_path.assert_called_with( - "/security/detections/" - ) - - @pytest.mark.anyio - async def test_ensure_folder_exists_with_existing_folders( - self, workflow_import_service - ): - """Test that _ensure_folder_exists handles existing folders.""" - # Mock folder service - mock_folder_service = AsyncMock() - workflow_import_service.folder_service = mock_folder_service - - # Mock that security folder exists but detections doesn't - mock_security_folder = Mock(id=uuid.uuid4()) - mock_detections_folder = Mock(id=uuid.uuid4()) - - mock_folder_service.get_folder_by_path.side_effect = [ - mock_security_folder, # /security/ exists - None, # /security/detections/ doesn't exist - mock_detections_folder, # final folder exists after creation - ] - - mock_folder_service.create_folder.return_value = mock_detections_folder - - await workflow_import_service._ensure_folder_exists("/security/detections/") - - # Verify only detections folder was created - assert mock_folder_service.create_folder.call_count == 1 - call_args = mock_folder_service.create_folder.call_args - assert call_args.kwargs["name"] == "detections" - assert call_args.kwargs["parent_path"] == "/security/" - - @pytest.mark.anyio - async def test_create_new_workflow_with_folder_path( - self, workflow_import_service, sample_remote_workflow_with_folder - ): - """Test creating a new workflow with folder_path sets folder_id.""" - # Mock dependencies - mock_wf_mgmt = AsyncMock() - mock_workflow = Mock() - mock_workflow.id = uuid.uuid4() - mock_wf_mgmt.create_db_workflow_from_dsl.return_value = mock_workflow - workflow_import_service.wf_mgmt = mock_wf_mgmt - - mock_defn_service = AsyncMock() - mock_defn = Mock(version=1) - mock_defn_service.create_workflow_definition.return_value = mock_defn - - # Mock session and flush - workflow_import_service.session.flush = AsyncMock() - - # Mock folder creation - test_folder_id = uuid.uuid4() - workflow_import_service._ensure_folder_exists = AsyncMock( - return_value=test_folder_id - ) - workflow_import_service._create_schedules = AsyncMock() - workflow_import_service._update_webhook = AsyncMock() - workflow_import_service._update_case_trigger = AsyncMock() - workflow_import_service._create_tags = AsyncMock() - - with patch( - "tracecat.workflow.store.import_service.WorkflowDefinitionsService", - return_value=mock_defn_service, - ): - await workflow_import_service._create_new_workflow( - sample_remote_workflow_with_folder - ) - - # Verify folder was created and workflow.folder_id was set - workflow_import_service._ensure_folder_exists.assert_called_once_with( - "/security/detections/" - ) - assert mock_workflow.folder_id == test_folder_id - - @pytest.mark.anyio - async def test_create_new_workflow_without_folder_path( - self, workflow_import_service, sample_remote_workflow - ): - """Test creating a new workflow without folder_path leaves folder_id as None.""" - # Mock dependencies - mock_wf_mgmt = AsyncMock() - mock_workflow = Mock() - mock_workflow.id = uuid.uuid4() - mock_wf_mgmt.create_db_workflow_from_dsl.return_value = mock_workflow - workflow_import_service.wf_mgmt = mock_wf_mgmt - - mock_defn_service = AsyncMock() - mock_defn = Mock(version=1) - mock_defn_service.create_workflow_definition.return_value = mock_defn - - # Mock session and flush - workflow_import_service.session.flush = AsyncMock() - - workflow_import_service._ensure_folder_exists = AsyncMock() - workflow_import_service._create_schedules = AsyncMock() - workflow_import_service._update_webhook = AsyncMock() - workflow_import_service._update_case_trigger = AsyncMock() - workflow_import_service._create_tags = AsyncMock() - - with patch( - "tracecat.workflow.store.import_service.WorkflowDefinitionsService", - return_value=mock_defn_service, - ): - await workflow_import_service._create_new_workflow(sample_remote_workflow) - - # Verify folder creation was not called and folder_id was not set - workflow_import_service._ensure_folder_exists.assert_not_called() - # folder_id should not be set (remains None by default) diff --git a/tests/unit/test_workspace_sync.py b/tests/unit/test_workspace_sync.py new file mode 100644 index 0000000000..1610ac160c --- /dev/null +++ b/tests/unit/test_workspace_sync.py @@ -0,0 +1,391 @@ +"""Tests for workspace Git sync projection primitives.""" + +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from typing import cast +from unittest.mock import AsyncMock, patch + +import pytest +import yaml +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from tracecat.auth.types import Role +from tracecat.cases.enums import CaseEventType +from tracecat.db.models import ( + Workflow, + WorkflowDefinition, + WorkspaceSyncChangeSet, + WorkspaceSyncChangeSetItem, + WorkspaceSyncResourceMapping, +) +from tracecat.dsl.common import DSLEntrypoint, DSLInput +from tracecat.dsl.schemas import ActionStatement +from tracecat.identifiers.workflow import WorkflowUUID +from tracecat.registry.lock.types import RegistryLock +from tracecat.sync import CommitInfo, PushStatus +from tracecat.workflow.management.definitions import WorkflowDefinitionsService +from tracecat.workflow.management.management import WorkflowsManagementService +from tracecat.workflow.store.schemas import RemoteWorkflowDefinition +from tracecat.workspace_sync.enums import SyncResourceType +from tracecat.workspace_sync.git import GitTreeSnapshot +from tracecat.workspace_sync.schemas import ( + ChangeSetCreate, + ChangeSetExport, + ResourceRef, + WorkspaceManifest, +) +from tracecat.workspace_sync.serialization import canonical_json_text, stable_hash +from tracecat.workspace_sync.service import WorkspaceGitSyncService +from tracecat.workspace_sync.workflow import ( + parse_workflow_spec, + serialize_workflow_spec, + workflow_spec_from_orm, +) + + +@pytest.fixture +def sample_dsl() -> DSLInput: + return DSLInput( + title="Detect Okta Risk", + description="Detects suspicious Okta activity", + entrypoint=DSLEntrypoint(ref="start", expects={}), + actions=[ + ActionStatement( + ref="start", + action="core.transform.passthrough", + args={"value": "test"}, + ) + ], + ) + + +def test_manifest_serializes_as_canonical_json() -> None: + text = canonical_json_text(WorkspaceManifest()) + + assert ( + text + == '{\n "resources": {\n "workflows": "workflows/"\n },\n "version": 1\n}\n' + ) + + +def test_stable_hash_ignores_model_defaults() -> None: + class HashModel(BaseModel): + name: str + future_default: str = "default" + + model_hash = stable_hash(HashModel(name="workflow")) + + assert model_hash.startswith("v1:") + assert model_hash == stable_hash({"name": "workflow"}) + + +def test_workflow_spec_does_not_serialize_local_uuid(sample_dsl: DSLInput) -> None: + local_id = uuid.uuid4() + workflow = SimpleNamespace( + id=local_id, + alias="okta-risk", + tags=[], + folder=None, + schedules=[], + webhook=SimpleNamespace( + methods=["POST"], status="online", include_headers=False + ), + case_trigger=SimpleNamespace( + status="offline", + event_types=[], + tag_filters=[], + ), + ) + + spec = workflow_spec_from_orm( + cast(Workflow, workflow), + dsl=sample_dsl, + source_id="detect-okta-risk", + ) + content = serialize_workflow_spec(spec) + + assert "detect-okta-risk" in content + assert str(local_id) not in content + assert "wf_" not in content + + +def test_workflow_spec_includes_configured_case_trigger(sample_dsl: DSLInput) -> None: + workflow = SimpleNamespace( + id=uuid.uuid4(), + alias="okta-risk", + tags=[], + folder=None, + schedules=[], + webhook=SimpleNamespace( + methods=["POST"], status="online", include_headers=False + ), + case_trigger=SimpleNamespace( + status="online", + event_types=[CaseEventType.CASE_CREATED.value], + tag_filters=["phishing"], + ), + ) + + spec = workflow_spec_from_orm( + cast(Workflow, workflow), + dsl=sample_dsl, + source_id="detect-okta-risk", + ) + + assert spec.case_trigger is not None + assert spec.case_trigger.status == "online" + assert spec.case_trigger.event_types == [CaseEventType.CASE_CREATED] + assert spec.case_trigger.tag_filters == ["phishing"] + + +def test_legacy_workflow_file_dual_reads_to_source_id( + sample_dsl: DSLInput, +) -> None: + legacy = RemoteWorkflowDefinition( + id="wf_0000000000000000000001", + alias="legacy-workflow", + definition=sample_dsl, + ) + content = yaml.safe_dump( + legacy.model_dump(mode="json", exclude_none=True), + sort_keys=False, + ) + + spec, diagnostic = parse_workflow_spec( + "workflows/legacy-source/definition.yml", + content, + ) + + assert diagnostic is None + assert spec is not None + assert spec.id == "legacy-source" + assert spec.alias == "legacy-workflow" + + +@pytest.mark.anyio +@pytest.mark.usefixtures("db") +async def test_resource_mapping_stores_source_id_to_local_uuid( + session: AsyncSession, + svc_role: Role, +) -> None: + service = WorkspaceGitSyncService(session=session, role=svc_role) + local_id = uuid.uuid4() + + mapping = await service._ensure_resource_mapping( + resource_type=SyncResourceType.WORKFLOW.value, + local_id=local_id, + preferred_source_id="detect-okta-risk", + source_path="workflows/detect-okta-risk/definition.yml", + create=True, + reserved_source_ids=set(), + ) + + assert mapping is not None + assert mapping.source_id == "detect-okta-risk" + assert mapping.local_id == local_id + assert mapping.workspace_id == svc_role.workspace_id + + +class FakeGitHubSyncTransport: + files: dict[str, str] = {} + written_files: dict[str, str] | None = None + written_branch: str | None = None + written_create_pr: bool | None = None + + def __init__(self, *args, **kwargs) -> None: + pass + + async def read_files(self, *args, **kwargs) -> GitTreeSnapshot: + return GitTreeSnapshot( + commit_sha="a" * 40, + tree_sha="b" * 40, + files=self.files, + ) + + async def write_files( + self, + *, + files: dict[str, str], + branch: str, + create_pr: bool, + **kwargs, + ) -> CommitInfo: + self.__class__.written_files = files + self.__class__.written_branch = branch + self.__class__.written_create_pr = create_pr + return CommitInfo( + status=PushStatus.COMMITTED, + sha="c" * 40, + ref=branch, + base_ref=kwargs.get("pr_base_branch") or "main", + pr_url="https://github.com/test-org/test-repo/pull/1" + if create_pr + else None, + pr_number=1 if create_pr else None, + pr_reused=False, + message="Committed workspace sync changes.", + ) + + +async def _create_local_workflow( + *, + session: AsyncSession, + role: Role, + dsl: DSLInput, + alias: str, +) -> Workflow: + with patch( + "tracecat.workflow.management.management.RegistryLockService.resolve_lock_with_bindings", + new=AsyncMock( + return_value=RegistryLock( + origins={"tracecat_registry": "test"}, + actions={"core.transform.passthrough": "tracecat_registry"}, + ) + ), + ): + workflow = await WorkflowsManagementService( + session=session, + role=role, + ).create_db_workflow_from_dsl( + dsl, + workflow_alias=alias, + commit=False, + ) + await WorkflowDefinitionsService( + session=session, + role=role, + ).create_workflow_definition( + WorkflowUUID.new(workflow.id), + dsl, + alias=alias, + commit=False, + ) + await session.commit() + return workflow + + +@pytest.mark.anyio +@pytest.mark.usefixtures("db") +async def test_status_pending_changeset_and_export_with_mocked_github( + session: AsyncSession, + svc_role: Role, + svc_workspace, + sample_dsl: DSLInput, +) -> None: + svc_workspace.settings = { + "git_repo_url": "git+ssh://git@github.com/test-org/test-repo.git" + } + session.add(svc_workspace) + await _create_local_workflow( + session=session, + role=svc_role, + dsl=sample_dsl, + alias="detect-okta-risk", + ) + + FakeGitHubSyncTransport.files = {} + FakeGitHubSyncTransport.written_files = None + FakeGitHubSyncTransport.written_branch = None + FakeGitHubSyncTransport.written_create_pr = None + + with patch( + "tracecat.workspace_sync.service.WorkspaceGitHubSyncService", + FakeGitHubSyncTransport, + ): + service = WorkspaceGitSyncService(session=session, role=svc_role) + + status = await service.get_status() + assert status.status == "never_synced" + assert status.pending_change_count == 1 + assert status.remote_commit_sha == "a" * 40 + + pending = await service.list_pending_changes() + assert len(pending.changes) == 1 + pending_change = pending.changes[0] + assert pending_change.operation == "create" + assert pending_change.source_id == "detect-okta-risk" + assert pending_change.title == "Detect Okta Risk" + mapping_before_changeset = await session.scalar( + select(WorkspaceSyncResourceMapping).where( + WorkspaceSyncResourceMapping.workspace_id == svc_role.workspace_id + ) + ) + assert mapping_before_changeset is None + + changeset = await service.create_changeset( + params=ChangeSetCreate( + title="Export workflow", + resources=[ + ResourceRef( + resource_type=pending_change.resource_type, + source_id=pending_change.source_id, + source_path=pending_change.source_path, + ) + ], + ) + ) + assert changeset.status == "validated" + assert changeset.selected_paths == [ + "tracecat.json", + "workflows/detect-okta-risk/definition.yml", + ] + assert changeset.selected_resources[0]["local_id"] is not None + changeset_row = await session.scalar( + select(WorkspaceSyncChangeSet).where( + WorkspaceSyncChangeSet.id == changeset.id + ) + ) + assert changeset_row is not None + assert set(changeset_row.rendered_files) == { + "tracecat.json", + "workflows/detect-okta-risk/definition.yml", + } + + changeset_item = await session.scalar( + select(WorkspaceSyncChangeSetItem).where( + WorkspaceSyncChangeSetItem.changeset_id == changeset.id + ) + ) + assert changeset_item is not None + assert changeset_item.operation == "create" + assert changeset_item.local_id is not None + + definition = await session.scalar( + select(WorkflowDefinition).where( + WorkflowDefinition.workspace_id == svc_role.workspace_id, + WorkflowDefinition.alias == "detect-okta-risk", + ) + ) + assert definition is not None + definition.content = {**definition.content, "title": "Mutated Okta Risk"} + session.add(definition) + await session.commit() + + result = await service.export_changeset( + changeset_id=changeset.id, + params=ChangeSetExport( + message="Export workflow", + branch="sync/detect-okta-risk", + create_pr=True, + ), + ) + + assert result.commit.status == PushStatus.COMMITTED + assert result.commit.pr_number == 1 + assert FakeGitHubSyncTransport.written_branch == "sync/detect-okta-risk" + assert FakeGitHubSyncTransport.written_create_pr is True + assert FakeGitHubSyncTransport.written_files is not None + assert set(FakeGitHubSyncTransport.written_files) == { + "tracecat.json", + "workflows/detect-okta-risk/definition.yml", + } + written_workflow = yaml.safe_load( + FakeGitHubSyncTransport.written_files[ + "workflows/detect-okta-risk/definition.yml" + ] + ) + assert written_workflow["definition"]["title"] == "Detect Okta Risk" diff --git a/tracecat/db/models.py b/tracecat/db/models.py index daed603318..58a031c4ff 100644 --- a/tracecat/db/models.py +++ b/tracecat/db/models.py @@ -513,6 +513,266 @@ class Workspace(OrganizationModel): ) +class WorkspaceSyncState(RecordModel): + """Workspace-level Git sync metadata.""" + + __tablename__ = "workspace_sync_state" + __table_args__ = ( + UniqueConstraint( + "workspace_id", + "provider", + "repo_url", + "target_ref", + name="uq_workspace_sync_state_workspace_provider_repo_ref", + ), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID, default=uuid.uuid4, nullable=False, unique=True, index=True + ) + workspace_id: Mapped[WorkspaceID] = mapped_column( + UUID, + ForeignKey("workspace.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + provider: Mapped[str] = mapped_column( + String(32), default="git", server_default=text("'git'"), nullable=False + ) + repo_url: Mapped[str] = mapped_column(String, nullable=False) + target_ref: Mapped[str] = mapped_column( + String, default="main", server_default=text("'main'"), nullable=False + ) + base_commit_sha: Mapped[str | None] = mapped_column(String, nullable=True) + base_tree_sha: Mapped[str | None] = mapped_column(String, nullable=True) + base_spec_hash: Mapped[str | None] = mapped_column(String, nullable=True) + last_remote_commit_sha: Mapped[str | None] = mapped_column(String, nullable=True) + last_remote_tree_sha: Mapped[str | None] = mapped_column(String, nullable=True) + status: Mapped[str] = mapped_column( + String(32), + default="never_synced", + server_default=text("'never_synced'"), + nullable=False, + ) + last_direction: Mapped[str | None] = mapped_column(String(16), nullable=True) + last_error: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + last_synced_at: Mapped[datetime | None] = mapped_column( + TIMESTAMP(timezone=True), nullable=True + ) + + +class WorkspaceSyncResourceMapping(RecordModel): + """Maps stable sync source identities to workspace-local resource UUIDs.""" + + __tablename__ = "workspace_sync_resource_mapping" + __table_args__ = ( + UniqueConstraint( + "workspace_id", + "provider", + "resource_type", + "source_id", + name="uq_workspace_sync_mapping_source", + ), + UniqueConstraint( + "workspace_id", + "provider", + "resource_type", + "local_id", + name="uq_workspace_sync_mapping_local", + ), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID, default=uuid.uuid4, nullable=False, unique=True, index=True + ) + workspace_id: Mapped[WorkspaceID] = mapped_column( + UUID, + ForeignKey("workspace.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + provider: Mapped[str] = mapped_column( + String(32), default="git", server_default=text("'git'"), nullable=False + ) + resource_type: Mapped[str] = mapped_column(String(64), nullable=False) + source_id: Mapped[str] = mapped_column(String, nullable=False) + source_path: Mapped[str | None] = mapped_column(String, nullable=True) + local_id: Mapped[uuid.UUID] = mapped_column(UUID, nullable=False) + last_synced_commit_sha: Mapped[str | None] = mapped_column(String, nullable=True) + last_synced_spec_hash: Mapped[str | None] = mapped_column(String, nullable=True) + sync_status: Mapped[str] = mapped_column( + String(32), + default="untracked", + server_default=text("'untracked'"), + nullable=False, + ) + + +class WorkspaceSyncEvent(RecordModel): + """Append-only provenance for Git-relevant workspace mutations.""" + + __tablename__ = "workspace_sync_event" + + id: Mapped[uuid.UUID] = mapped_column( + UUID, default=uuid.uuid4, nullable=False, unique=True, index=True + ) + workspace_id: Mapped[WorkspaceID] = mapped_column( + UUID, + ForeignKey("workspace.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + provider: Mapped[str] = mapped_column( + String(32), default="git", server_default=text("'git'"), nullable=False + ) + resource_type: Mapped[str] = mapped_column(String(64), nullable=False) + source_id: Mapped[str | None] = mapped_column(String, nullable=True) + local_id: Mapped[uuid.UUID | None] = mapped_column(UUID, nullable=True) + operation: Mapped[str] = mapped_column(String(32), nullable=False) + actor_id: Mapped[uuid.UUID | None] = mapped_column(UUID, nullable=True) + base_commit_sha: Mapped[str | None] = mapped_column(String, nullable=True) + before_spec_hash: Mapped[str | None] = mapped_column(String, nullable=True) + after_spec_hash: Mapped[str | None] = mapped_column(String, nullable=True) + affected_paths: Mapped[list[str]] = mapped_column( + JSONB, default=list, server_default=text("'[]'::jsonb"), nullable=False + ) + metadata_: Mapped[dict[str, Any]] = mapped_column( + "metadata", + JSONB, + default=dict, + server_default=text("'{}'::jsonb"), + nullable=False, + ) + superseded_by: Mapped[uuid.UUID | None] = mapped_column(UUID, nullable=True) + changeset_id: Mapped[uuid.UUID | None] = mapped_column(UUID, nullable=True) + + +class WorkspaceSyncChangeSet(RecordModel): + """Tracecat-side review and export unit for syncable workspace changes.""" + + __tablename__ = "workspace_sync_changeset" + + id: Mapped[uuid.UUID] = mapped_column( + UUID, default=uuid.uuid4, nullable=False, unique=True, index=True + ) + workspace_id: Mapped[WorkspaceID] = mapped_column( + UUID, + ForeignKey("workspace.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + provider: Mapped[str] = mapped_column( + String(32), default="git", server_default=text("'git'"), nullable=False + ) + title: Mapped[str] = mapped_column(String, nullable=False) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + base_commit_sha: Mapped[str | None] = mapped_column(String, nullable=True) + base_spec_hash: Mapped[str | None] = mapped_column(String, nullable=True) + selected_resources: Mapped[list[dict[str, Any]]] = mapped_column( + JSONB, default=list, server_default=text("'[]'::jsonb"), nullable=False + ) + selected_paths: Mapped[list[str]] = mapped_column( + JSONB, default=list, server_default=text("'[]'::jsonb"), nullable=False + ) + rendered_files: Mapped[dict[str, str]] = mapped_column( + JSONB, default=dict, server_default=text("'{}'::jsonb"), nullable=False + ) + validation_status: Mapped[str] = mapped_column( + String(32), + default="pending", + server_default=text("'pending'"), + nullable=False, + ) + validation_result: Mapped[dict[str, Any]] = mapped_column( + JSONB, default=dict, server_default=text("'{}'::jsonb"), nullable=False + ) + status: Mapped[str] = mapped_column( + String(32), + default="open", + server_default=text("'open'"), + nullable=False, + ) + created_by: Mapped[uuid.UUID | None] = mapped_column(UUID, nullable=True) + + +class WorkspaceSyncChangeSetItem(RecordModel): + """A resource/path selected into a workspace sync ChangeSet.""" + + __tablename__ = "workspace_sync_changeset_item" + __table_args__ = ( + UniqueConstraint( + "changeset_id", + "resource_type", + "source_id", + name="uq_workspace_sync_changeset_item_resource", + ), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID, default=uuid.uuid4, nullable=False, unique=True, index=True + ) + workspace_id: Mapped[WorkspaceID] = mapped_column( + UUID, + ForeignKey("workspace.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + changeset_id: Mapped[uuid.UUID] = mapped_column( + UUID, + ForeignKey("workspace_sync_changeset.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + resource_type: Mapped[str] = mapped_column(String(64), nullable=False) + source_id: Mapped[str] = mapped_column(String, nullable=False) + source_path: Mapped[str | None] = mapped_column(String, nullable=True) + local_id: Mapped[uuid.UUID | None] = mapped_column(UUID, nullable=True) + operation: Mapped[str] = mapped_column(String(32), nullable=False) + spec_hash: Mapped[str | None] = mapped_column(String, nullable=True) + dependencies: Mapped[list[dict[str, Any]]] = mapped_column( + JSONB, default=list, server_default=text("'[]'::jsonb"), nullable=False + ) + + +class WorkspaceSyncMaterialization(RecordModel): + """Git branch, commit, and PR output for a materialized ChangeSet.""" + + __tablename__ = "workspace_sync_materialization" + + id: Mapped[uuid.UUID] = mapped_column( + UUID, default=uuid.uuid4, nullable=False, unique=True, index=True + ) + workspace_id: Mapped[WorkspaceID] = mapped_column( + UUID, + ForeignKey("workspace.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + changeset_id: Mapped[uuid.UUID] = mapped_column( + UUID, + ForeignKey("workspace_sync_changeset.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + provider: Mapped[str] = mapped_column( + String(32), default="git", server_default=text("'git'"), nullable=False + ) + branch: Mapped[str] = mapped_column(String, nullable=False) + base_ref: Mapped[str | None] = mapped_column(String, nullable=True) + pr_number: Mapped[int | None] = mapped_column(Integer, nullable=True) + pr_url: Mapped[str | None] = mapped_column(String, nullable=True) + commit_shas: Mapped[list[str]] = mapped_column( + JSONB, default=list, server_default=text("'[]'::jsonb"), nullable=False + ) + status: Mapped[str] = mapped_column( + String(32), + default="pending", + server_default=text("'pending'"), + nullable=False, + ) + error: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + + class User(SQLAlchemyBaseUserTableUUID, Base): __tablename__ = "user" diff --git a/tracecat/db/tenant_rls.py b/tracecat/db/tenant_rls.py index 21d76636aa..b1ae913ce9 100644 --- a/tracecat/db/tenant_rls.py +++ b/tracecat/db/tenant_rls.py @@ -84,6 +84,12 @@ "skill_version_file", "agent_preset_skill", "agent_preset_version_skill", + "workspace_sync_state", + "workspace_sync_resource_mapping", + "workspace_sync_event", + "workspace_sync_changeset", + "workspace_sync_changeset_item", + "workspace_sync_materialization", ) POST_RLS_ORG_SCOPED_TABLES = ( diff --git a/tracecat/workflow/store/router.py b/tracecat/workflow/store/router.py index 1d4bd33c51..1dfe3e4184 100644 --- a/tracecat/workflow/store/router.py +++ b/tracecat/workflow/store/router.py @@ -1,3 +1,5 @@ +import uuid + from fastapi import APIRouter, HTTPException, Query, status from tracecat import config @@ -7,6 +9,7 @@ from tracecat.dsl.common import DSLInput from tracecat.exceptions import ( TracecatCredentialsNotFoundError, + TracecatNotFoundError, TracecatSettingsError, TracecatValidationError, ) @@ -21,9 +24,18 @@ WorkflowDslPublish, WorkflowDslPublishResult, WorkflowSyncPullRequest, + validate_short_branch_name, ) from tracecat.workflow.store.service import WorkflowStoreService -from tracecat.workflow.store.sync import WorkflowSyncService +from tracecat.workspace_sync.schemas import ( + ChangeSetCreate, + ChangeSetExport, + ChangeSetRead, + WorkspaceSyncExportResult, + WorkspaceSyncPendingChanges, + WorkspaceSyncStatus, +) +from tracecat.workspace_sync.service import WorkspaceGitSyncService from tracecat.workspaces.service import WorkspaceService router = APIRouter(prefix="/workflows", tags=["workflows"]) @@ -136,8 +148,8 @@ async def list_workflow_commits( # Parse and validate Git URL git_url = parse_git_url(repository_url) - # Initialize workflow sync service - sync_service = WorkflowSyncService(session=session, role=role) + # Initialize workspace sync service + sync_service = WorkspaceGitSyncService(session=session, role=role) # Fetch commits using GitHub App API commits = await sync_service.list_commits( @@ -215,7 +227,7 @@ async def list_workflow_branches( ) git_url = parse_git_url(repository_url) - sync_service = WorkflowSyncService(session=session, role=role) + sync_service = WorkspaceGitSyncService(session=session, role=role) branches = await sync_service.list_branches(url=git_url, limit=limit) return branches except HTTPException: @@ -246,6 +258,165 @@ async def list_workflow_branches( ) from e +@router.get("/sync/status", response_model=WorkspaceSyncStatus) +@require_scope("workflow:sync") +async def get_workspace_sync_status( + role: WorkspaceActorRouteRole, + session: AsyncDBSession, +) -> WorkspaceSyncStatus: + """Get workspace-level Git sync status for the configured repository.""" + if not role.workspace_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Workspace ID is required", + ) + sync_service = WorkspaceGitSyncService(session=session, role=role) + try: + return await sync_service.get_status() + except TracecatSettingsError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + + +@router.get("/sync/pending", response_model=WorkspaceSyncPendingChanges) +@require_scope("workflow:sync") +async def list_workspace_sync_pending_changes( + role: WorkspaceActorRouteRole, + session: AsyncDBSession, +) -> WorkspaceSyncPendingChanges: + """List local syncable workspace changes pending Git export.""" + if not role.workspace_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Workspace ID is required", + ) + sync_service = WorkspaceGitSyncService(session=session, role=role) + try: + return await sync_service.list_pending_changes() + except TracecatSettingsError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + + +@router.get("/sync/changesets", response_model=list[ChangeSetRead]) +@require_scope("workflow:sync") +async def list_workspace_sync_changesets( + role: WorkspaceActorRouteRole, + session: AsyncDBSession, + limit: int = Query(default=50, ge=1, le=100), +) -> list[ChangeSetRead]: + """List workspace sync ChangeSets.""" + if not role.workspace_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Workspace ID is required", + ) + sync_service = WorkspaceGitSyncService(session=session, role=role) + return await sync_service.list_changesets(limit=limit) + + +@router.post( + "/sync/changesets", + response_model=ChangeSetRead, + status_code=status.HTTP_201_CREATED, +) +@require_scope("workflow:update", "workflow:sync") +async def create_workspace_sync_changeset( + role: WorkspaceActorRouteRole, + session: AsyncDBSession, + params: ChangeSetCreate, +) -> ChangeSetRead: + """Create a workspace sync ChangeSet from selected pending resources.""" + if not role.workspace_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Workspace ID is required", + ) + sync_service = WorkspaceGitSyncService(session=session, role=role) + try: + return await sync_service.create_changeset(params) + except (TracecatSettingsError, TracecatValidationError) as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + + +@router.get("/sync/changesets/{changeset_id}", response_model=ChangeSetRead) +@require_scope("workflow:sync") +async def get_workspace_sync_changeset( + role: WorkspaceActorRouteRole, + session: AsyncDBSession, + changeset_id: uuid.UUID, +) -> ChangeSetRead: + """Get a workspace sync ChangeSet.""" + if not role.workspace_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Workspace ID is required", + ) + sync_service = WorkspaceGitSyncService(session=session, role=role) + try: + return await sync_service.get_changeset(changeset_id) + except TracecatNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + + +@router.post( + "/sync/changesets/{changeset_id}/export", + response_model=WorkspaceSyncExportResult, +) +@require_scope("workflow:update", "workflow:sync") +async def export_workspace_sync_changeset( + role: WorkspaceActorRouteRole, + session: AsyncDBSession, + changeset_id: uuid.UUID, + params: ChangeSetExport, +) -> WorkspaceSyncExportResult: + """Export a workspace sync ChangeSet to a Git branch and optional PR.""" + if not role.workspace_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Workspace ID is required", + ) + try: + validate_short_branch_name(params.branch, field_name="branch") + if params.pr_base_branch: + validate_short_branch_name( + params.pr_base_branch, + field_name="pr_base_branch", + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + + sync_service = WorkspaceGitSyncService(session=session, role=role) + try: + return await sync_service.export_changeset( + changeset_id=changeset_id, + params=params, + ) + except TracecatNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + except (TracecatSettingsError, TracecatValidationError, GitHubAppError) as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + + @router.post("/sync/pull", response_model=PullResult) @require_scope("workflow:update", "workflow:sync") async def pull_workflows( @@ -294,8 +465,8 @@ async def pull_workflows( dry_run=params.dry_run, ) - # Initialize workflow sync service - sync_service = WorkflowSyncService(session=session, role=role) + # Initialize workspace sync service + sync_service = WorkspaceGitSyncService(session=session, role=role) # Perform the pull operation return await sync_service.pull(url=git_url, options=pull_options) diff --git a/tracecat/workflow/store/service.py b/tracecat/workflow/store/service.py index 30be8bd343..af72c6f702 100644 --- a/tracecat/workflow/store/service.py +++ b/tracecat/workflow/store/service.py @@ -1,30 +1,20 @@ +from datetime import UTC, datetime from pathlib import Path -from typing import cast from tracecat.authz.controls import require_scope -from tracecat.cases.enums import CaseEventType from tracecat.db.models import Workflow from tracecat.dsl.common import DSLInput -from tracecat.exceptions import TracecatSettingsError, TracecatValidationError -from tracecat.git.utils import parse_git_url +from tracecat.exceptions import TracecatValidationError from tracecat.identifiers.workflow import WorkflowUUID from tracecat.logger import logger from tracecat.service import BaseWorkspaceService -from tracecat.sync import Author, PushObject, PushOptions, PushStatus -from tracecat.workflow.case_triggers.schemas import is_case_trigger_configured +from tracecat.sync import Author, PushOptions from tracecat.workflow.store.schemas import ( - RemoteCaseTrigger, - RemoteWebhook, - RemoteWorkflowDefinition, - RemoteWorkflowSchedule, - RemoteWorkflowTag, - Status, WorkflowDslPublish, WorkflowDslPublishResult, validate_short_branch_name, ) -from tracecat.workflow.store.sync import WorkflowSyncService -from tracecat.workspaces.service import WorkspaceService +from tracecat.workspace_sync.service import WorkspaceGitSyncService class WorkflowStoreService(BaseWorkspaceService): @@ -46,97 +36,17 @@ async def publish_workflow_dsl( f"Workflow ID mismatch: provided {workflow_id} but workflow object has ID {workflow.id}" ) - # Get workspace settings for git configuration - - workspace_service = WorkspaceService(session=self.session, role=self.role) - workspace = await workspace_service.get_workspace(self.workspace_id) - - if not workspace: - raise TracecatValidationError("Workspace not found") - - # Extract git configuration from workspace settings - git_repo_url = workspace.settings.get("git_repo_url") - if not git_repo_url: - raise TracecatSettingsError( - "Git repository URL not configured for this workspace. " - "Please contact your administrator to configure it." - ) - logger.info( "Publishing workflow to store", workflow_title=dsl.title, - repo_url=git_repo_url, workspace_id=self.workspace_id, ) - # Parse the Git URL using workspace settings - try: - git_url = parse_git_url(git_repo_url, allowed_domains={"github.com"}) - except ValueError as e: - raise TracecatSettingsError( - f"Invalid Git repository URL configured for this workspace: {e}. " - "Please contact your administrator to fix the configuration." - ) from e - # Note: We could add ref support later if needed via params or workspace settings - - stable_path = get_definition_path(workflow_id) - webhook = workflow.webhook - - await self.session.refresh(workflow, ["tags", "folder", "case_trigger"]) - - # Get folder path if workflow is in a folder - folder_path = None - if workflow.folder: - folder_path = workflow.folder.path - - # Create PushObject with data and stable path - case_trigger = None - if workflow.case_trigger and is_case_trigger_configured( - status=workflow.case_trigger.status, - event_types=workflow.case_trigger.event_types, - tag_filters=workflow.case_trigger.tag_filters, - ): - case_trigger = RemoteCaseTrigger( - status=cast(Status, workflow.case_trigger.status), - event_types=[ - CaseEventType(event_type) - for event_type in workflow.case_trigger.event_types - ], - tag_filters=workflow.case_trigger.tag_filters, - ) - - defn = RemoteWorkflowDefinition( - id=workflow_id.short(), - alias=workflow.alias, - folder_path=folder_path, - tags=[RemoteWorkflowTag(name=t.name) for t in workflow.tags], - # Convert Schedule ORM objects to RemoteWorkflowSchedule, handling type conversions and missing fields. - schedules=[ - RemoteWorkflowSchedule( - status=cast(Status, s.status), - cron=s.cron, - every=s.every, - offset=s.offset, - start_at=s.start_at, - end_at=s.end_at, - timeout=s.timeout, - ) - for s in (workflow.schedules or []) - ], - webhook=RemoteWebhook( - methods=webhook.methods, - status=cast(Status, webhook.status), - include_headers=webhook.include_headers, - ), - case_trigger=case_trigger, - definition=dsl, - ) - push_obj = PushObject(data=defn, path=stable_path) - author = Author(name="Tracecat", email="noreply@tracecat.com") publish_message = params.message or f"Publish workflow: {dsl.title}" - validated_branch: str | None = None + validated_branch: str validated_pr_base_branch: str | None = None + create_pr = params.create_pr if params.branch is not None: try: @@ -151,65 +61,45 @@ async def publish_workflow_dsl( ) except ValueError as e: raise TracecatValidationError(str(e)) from e - - if params.branch is None: + else: logger.warning( "workflow_publish_legacy_mode_used", workflow_id=str(workflow_id), workspace_id=str(self.workspace_id), ) - push_options = PushOptions( - message=publish_message, - author=author, - create_pr=True, - ) - else: - push_options = PushOptions( - message=publish_message, - author=author, - create_pr=params.create_pr, - branch=validated_branch, - pr_base_branch=validated_pr_base_branch, + validated_branch = validate_short_branch_name( + f"tracecat-sync-{datetime.now(UTC).strftime('%Y%m%d-%H%M%S')}", + field_name="branch", ) + create_pr = True + + push_options = PushOptions( + message=publish_message, + author=author, + create_pr=create_pr, + branch=validated_branch, + pr_base_branch=validated_pr_base_branch, + ) - # Use WorkflowSyncService to push the workflow with stable path - sync_service = WorkflowSyncService(session=self.session, role=self.role) - commit_info = await sync_service.push( - objects=[push_obj], - url=git_url, + sync_service = WorkspaceGitSyncService(session=self.session, role=self.role) + result = await sync_service.export_workflow_publish_result( + workflow=workflow, + dsl=dsl, options=push_options, ) - if validated_branch is not None and commit_info.status in { - PushStatus.COMMITTED, - PushStatus.NO_OP, - }: - workflow.git_sync_branch = validated_branch - self.session.add(workflow) - await self.session.commit() - logger.info( "Successfully published workflow", workflow_title=dsl.title, - status=commit_info.status.value, - commit_sha=commit_info.sha, - ref=commit_info.ref, - base_ref=commit_info.base_ref, - pr_url=commit_info.pr_url, - pr_number=commit_info.pr_number, - pr_reused=commit_info.pr_reused, - ) - - return WorkflowDslPublishResult( - status=commit_info.status.value, - commit_sha=commit_info.sha, - branch=commit_info.ref, - base_branch=commit_info.base_ref, - pr_url=commit_info.pr_url, - pr_number=commit_info.pr_number, - pr_reused=commit_info.pr_reused, - message=commit_info.message, + status=result.status, + commit_sha=result.commit_sha, + ref=result.branch, + base_ref=result.base_branch, + pr_url=result.pr_url, + pr_number=result.pr_number, + pr_reused=result.pr_reused, ) + return result def get_definition_path(workflow_id: WorkflowUUID) -> Path: diff --git a/tracecat/workflow/store/sync.py b/tracecat/workflow/store/sync.py deleted file mode 100644 index 82bd2d4ecc..0000000000 --- a/tracecat/workflow/store/sync.py +++ /dev/null @@ -1,846 +0,0 @@ -"""Workflow synchronization functionality for Tracecat.""" - -from __future__ import annotations - -import asyncio -import base64 -from collections.abc import Sequence -from datetime import datetime -from typing import TYPE_CHECKING, Any - -import yaml -from github.GithubException import GithubException - -if TYPE_CHECKING: - from github.ContentFile import ContentFile -from pydantic import ValidationError - -from tracecat.db.models import User -from tracecat.exceptions import TracecatNotFoundError -from tracecat.git.utils import GitUrl -from tracecat.logger import logger -from tracecat.registry.repositories.schemas import GitBranchInfo, GitCommitInfo -from tracecat.service import BaseWorkspaceService -from tracecat.sync import ( - CommitInfo, - PullDiagnostic, - PullOptions, - PullResult, - PushObject, - PushOptions, - PushStatus, -) -from tracecat.vcs.github.app import GitHubAppError, GitHubAppService -from tracecat.workflow.store.import_service import WorkflowImportService -from tracecat.workflow.store.schemas import RemoteWorkflowDefinition -from tracecat.workspaces.service import WorkspaceService - - -# NOTE: Internal service called by higher level services, shouldn't use directly -class WorkflowSyncService(BaseWorkspaceService): - """Git synchronization service for workflow definitions. - - Implements the SyncService protocol for DSLInput workflow models, - providing pull/push operations with Git repositories. - """ - - service_name = "workflow_sync" - - async def pull( - self, - *, - url: GitUrl, - options: PullOptions | None = None, - ) -> PullResult: - """Pull workflow definitions from a Git repository at specific commit SHA. - - This implementation provides atomic guarantees - either all workflows - are imported successfully or none are. - - Args: - url: Git repository URL - options: Pull options including commit SHA and conflict strategy - - Returns: - PullResult with success status and diagnostics - - Raises: - GitHubAppError: If GitHub authentication or API errors occur - """ - if not options or not options.commit_sha: - return PullResult( - success=False, - commit_sha="", - workflows_found=0, - workflows_imported=0, - diagnostics=[ - PullDiagnostic( - workflow_path="", - workflow_title=None, - error_type="validation", - message="commit_sha is required in pull options", - details={}, - ) - ], - message="commit_sha is required", - ) - - try: - # 1. Fetch repository content at specific commit SHA - repo_content = await self._fetch_repository_content(url, options.commit_sha) - - # 2. Parse workflow definitions - ( - remote_workflows, - parse_diagnostics, - ) = await self._parse_workflow_definitions(repo_content) - - if parse_diagnostics: - return PullResult( - success=False, - commit_sha=options.commit_sha, - workflows_found=len(repo_content), - workflows_imported=0, - diagnostics=parse_diagnostics, - message=f"Failed to parse {len(parse_diagnostics)} workflow definitions", - ) - - # 3. Import workflows atomically - if options.dry_run: - # For dry run, skip import and return validation-only result - return PullResult( - success=True, - commit_sha=options.commit_sha, - workflows_found=len(remote_workflows), - workflows_imported=0, - diagnostics=[], - message="Dry run completed - workflows validated but not imported", - ) - - import_service = WorkflowImportService(session=self.session, role=self.role) - - return await import_service.import_workflows_atomic( - remote_workflows=remote_workflows, - commit_sha=options.commit_sha, - ) - - except GitHubAppError as e: - logger.error(f"GitHub API error during pull: {e}") - return PullResult( - success=False, - commit_sha=options.commit_sha or "", - workflows_found=0, - workflows_imported=0, - diagnostics=[ - PullDiagnostic( - workflow_path="", - workflow_title=None, - error_type="github", - message=f"GitHub API error: {str(e)}", - details={"error": str(e)}, - ) - ], - message="GitHub API error", - ) - except Exception as e: - logger.error(f"Unexpected error during pull: {e}", exc_info=True) - return PullResult( - success=False, - commit_sha=options.commit_sha or "", - workflows_found=0, - workflows_imported=0, - diagnostics=[ - PullDiagnostic( - workflow_path="", - workflow_title=None, - error_type="system", - message=f"Unexpected error: {str(e)}", - details={"error": str(e)}, - ) - ], - message="System error", - ) - - async def _fetch_repository_content( - self, url: GitUrl, commit_sha: str - ) -> dict[str, str]: - """Fetch workflow definitions from repository at specific commit SHA. - - Args: - url: Git repository URL - commit_sha: Specific commit SHA to fetch from - - Returns: - Dictionary mapping file paths to file content - - Raises: - GitHubAppError: If GitHub API errors occur - """ - gh_svc = GitHubAppService(session=self.session, role=self.role) - gh = await gh_svc.get_github_client_for_repo(url) - - try: - repo = await asyncio.to_thread(gh.get_repo, f"{url.org}/{url.repo}") - - # Get the workflows directory at the specific commit - try: - workflows_contents = await asyncio.to_thread( - repo.get_contents, "workflows", ref=commit_sha - ) - - if not isinstance(workflows_contents, list): - # workflows is a file, not a directory - return {} - - content_map = {} - - for item in workflows_contents: - item: ContentFile = item # type hint for GitHub API object - # Look for workflow directories - if item.type == "dir": - # Get definition.yml from each workflow directory - definition_path = f"{item.path}/definition.yml" - try: - definition_file = await asyncio.to_thread( - repo.get_contents, definition_path, ref=commit_sha - ) - - if not isinstance(definition_file, list) and hasattr( - definition_file, "content" - ): - # Decode base64 content - content = base64.b64decode( - definition_file.content - ).decode("utf-8") - content_map[definition_path] = content - except GithubException as e: - if e.status != 404: # Ignore missing definition.yml files - logger.warning(f"Failed to get {definition_path}: {e}") - - return content_map - - except GithubException as e: - if e.status == 404: - # No workflows directory found - return {} - raise - - except GithubException as e: - raise GitHubAppError(f"GitHub API error: {e.status} - {e.data}") from e - finally: - gh.close() - - async def _parse_workflow_definitions( - self, content_map: dict[str, str] - ) -> tuple[list[RemoteWorkflowDefinition], list[PullDiagnostic]]: - """Parse workflow definitions from file contents. - - Args: - content_map: Dictionary mapping file paths to content - - Returns: - Tuple of (remote_workflows, diagnostics) - """ - remote_workflows: list[RemoteWorkflowDefinition] = [] - diagnostics: list[PullDiagnostic] = [] - - for file_path, content in content_map.items(): - yaml_data: dict[str, Any] | None = None - try: - # Parse YAML content - yaml_data = yaml.safe_load(content) - if not yaml_data: - diagnostics.append( - PullDiagnostic( - workflow_path=file_path, - workflow_title=None, - error_type="parse", - message="Empty or invalid YAML file", - details={}, - ) - ) - continue - - # Convert to RemoteWorkflowDefinition - remote_workflow = RemoteWorkflowDefinition.model_validate(yaml_data) - remote_workflows.append(remote_workflow) - - except yaml.YAMLError as e: - diagnostics.append( - PullDiagnostic( - workflow_path=file_path, - workflow_title=None, - error_type="parse", - message=f"YAML parsing error: {str(e)}", - details={"yaml_error": str(e)}, - ) - ) - except ValidationError as e: - diagnostics.append( - PullDiagnostic( - workflow_path=file_path, - workflow_title=yaml_data.get("definition", {}).get("title") - if isinstance(yaml_data, dict) - else None, - error_type="validation", - message=f"Validation error: {str(e)}", - details={"validation_errors": e.errors()}, - ) - ) - except Exception as e: - diagnostics.append( - PullDiagnostic( - workflow_path=file_path, - workflow_title=None, - error_type="parse", - message=f"Unexpected parsing error: {str(e)}", - details={"error": str(e)}, - ) - ) - - return remote_workflows, diagnostics - - async def push( - self, - *, - objects: Sequence[PushObject[RemoteWorkflowDefinition]], - url: GitUrl, - options: PushOptions, - ) -> CommitInfo: - """Push workflow definitions using GitHub App API operations. - - Args: - objects: PushObjects containing workflow definitions and target paths - url: Git repository URL with target branch - options: Push options including commit message and PR flag - - Returns: - CommitInfo with commit SHA and branch/PR details - """ - if len(objects) != 1: - raise ValueError("We only support pushing one workflow object at a time") - - [obj] = objects - - gh_svc = GitHubAppService(session=self.session, role=self.role) - - # Use new PyGithub-based method that handles installation resolution automatically - gh = await gh_svc.get_github_client_for_repo(url) - - try: - repo = await asyncio.to_thread(gh.get_repo, f"{url.org}/{url.repo}") - - if options.branch is not None: - return await self._push_to_target_branch( - repo=repo, - url=url, - obj=obj, - options=options, - ) - - return await self._push_legacy( - repo=repo, - url=url, - obj=obj, - options=options, - ) - - except GithubException as e: - logger.error( - "GitHub API error during push", - status=e.status, - data=e.data, - repo=f"{url.org}/{url.repo}", - ) - raise GitHubAppError(f"GitHub API error: {e.status} - {e.data}") from e - finally: - gh.close() - - async def _push_to_target_branch( - self, - *, - repo: Any, - url: GitUrl, - obj: PushObject[RemoteWorkflowDefinition], - options: PushOptions, - ) -> CommitInfo: - branch_name = options.branch - if branch_name is None: - raise ValueError("branch is required for target-branch push mode") - base_branch_name = options.pr_base_branch or url.ref or repo.default_branch - base_branch = await asyncio.to_thread(repo.get_branch, base_branch_name) - - try: - await asyncio.to_thread(repo.get_branch, branch_name) - except GithubException as e: - if e.status != 404: - raise - await asyncio.to_thread( - repo.create_git_ref, - ref=f"refs/heads/{branch_name}", - sha=base_branch.commit.sha, - ) - logger.info( - "Created target branch via GitHub API", - branch=branch_name, - base_branch=base_branch_name, - repo=f"{url.org}/{url.repo}", - ) - - file_path = obj.path_str - yaml_content = yaml.dump( - obj.data.model_dump(mode="json", exclude_none=True, exclude_unset=True), - sort_keys=False, - ) - - pr_url: str | None = None - pr_number: int | None = None - pr_reused = False - - try: - contents = await asyncio.to_thread( - repo.get_contents, file_path, ref=branch_name - ) - if isinstance(contents, list): - raise GithubException(404, {"message": "Not a file"}, {}) - existing_content = base64.b64decode(contents.content).decode("utf-8") - if existing_content == yaml_content: - if options.create_pr: - pr_url, pr_number, pr_reused = await self._upsert_pull_request_safe( - repo=repo, - url=url, - obj=obj, - options=options, - branch_name=branch_name, - base_branch_name=base_branch_name, - ) - return CommitInfo( - status=PushStatus.NO_OP, - sha=None, - ref=branch_name, - base_ref=base_branch_name, - pr_url=pr_url, - pr_number=pr_number, - pr_reused=pr_reused, - message="No changes detected; nothing to commit.", - ) - - await asyncio.to_thread( - repo.update_file, - path=contents.path, - message=options.message, - content=yaml_content, - sha=contents.sha, - branch=branch_name, - ) - logger.debug( - "Updated workflow file via API", - path=file_path, - branch=branch_name, - ) - except GithubException as e: - if e.status != 404: - raise - await asyncio.to_thread( - repo.create_file, - path=file_path, - message=options.message, - content=yaml_content, - branch=branch_name, - ) - logger.debug( - "Created workflow file via API", - path=file_path, - branch=branch_name, - ) - - branch = await asyncio.to_thread(repo.get_branch, branch_name) - commit_sha = branch.commit.sha - - if options.create_pr: - pr_url, pr_number, pr_reused = await self._upsert_pull_request_safe( - repo=repo, - url=url, - obj=obj, - options=options, - branch_name=branch_name, - base_branch_name=base_branch_name, - ) - - logger.info( - "Successfully pushed workflow via GitHub API", - branch=branch_name, - base_branch=base_branch_name, - commit_sha=commit_sha, - pr_url=pr_url, - pr_number=pr_number, - pr_reused=pr_reused, - ) - - return CommitInfo( - status=PushStatus.COMMITTED, - sha=commit_sha, - ref=branch_name, - base_ref=base_branch_name, - pr_url=pr_url, - pr_number=pr_number, - pr_reused=pr_reused, - message="Committed workflow changes.", - ) - - async def _push_legacy( - self, - *, - repo: Any, - url: GitUrl, - obj: PushObject[RemoteWorkflowDefinition], - options: PushOptions, - ) -> CommitInfo: - base_branch_name = url.ref or repo.default_branch - base_branch = await asyncio.to_thread(repo.get_branch, base_branch_name) - - timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") - branch_name = f"tracecat-sync-{timestamp}" - - logger.info( - "Creating legacy temp branch via GitHub API", - branch=branch_name, - base_branch=base_branch_name, - repo=f"{url.org}/{url.repo}", - ) - - await asyncio.to_thread( - repo.create_git_ref, - ref=f"refs/heads/{branch_name}", - sha=base_branch.commit.sha, - ) - - file_path = obj.path_str - yaml_content = yaml.dump( - obj.data.model_dump(mode="json", exclude_none=True, exclude_unset=True), - sort_keys=False, - ) - - try: - contents = await asyncio.to_thread( - repo.get_contents, file_path, ref=branch_name - ) - if isinstance(contents, list): - raise GithubException(404, {"message": "Not a file"}, {}) - await asyncio.to_thread( - repo.update_file, - path=contents.path, - message=options.message, - content=yaml_content, - sha=contents.sha, - branch=branch_name, - ) - logger.debug( - "Updated workflow file via API (legacy)", - path=file_path, - branch=branch_name, - ) - except GithubException as e: - if e.status != 404: - raise - await asyncio.to_thread( - repo.create_file, - path=file_path, - message=options.message, - content=yaml_content, - branch=branch_name, - ) - logger.debug( - "Created workflow file via API (legacy)", - path=file_path, - branch=branch_name, - ) - - branch = await asyncio.to_thread(repo.get_branch, branch_name) - commit_sha = branch.commit.sha - - pr_url: str | None = None - pr_number: int | None = None - if options.create_pr: - try: - pr_url, pr_number, _ = await self._create_pull_request( - repo=repo, - obj=obj, - options=options, - branch_name=branch_name, - base_branch_name=base_branch_name, - ) - except GithubException as e: - logger.error( - "Failed to create PR via GitHub API (legacy)", - error=str(e), - branch=branch_name, - ) - - return CommitInfo( - status=PushStatus.COMMITTED, - sha=commit_sha, - ref=branch_name, - base_ref=base_branch_name, - pr_url=pr_url, - pr_number=pr_number, - pr_reused=False, - message="Committed workflow changes on a temporary legacy branch.", - ) - - async def _upsert_pull_request( - self, - *, - repo: Any, - url: GitUrl, - obj: PushObject[RemoteWorkflowDefinition], - options: PushOptions, - branch_name: str, - base_branch_name: str, - ) -> tuple[str | None, int | None, bool]: - def _first_open_pull_request() -> Any | None: - pulls = repo.get_pulls( - state="open", - head=f"{url.org}:{branch_name}", - base=base_branch_name, - ) - return next(iter(pulls), None) - - try: - existing_pr = await asyncio.to_thread(_first_open_pull_request) - except GithubException as e: - logger.error( - "Failed to search for open pull requests", - error=str(e), - branch=branch_name, - base_branch=base_branch_name, - ) - existing_pr = None - - if existing_pr is not None: - logger.info( - "Reused existing pull request", - pr_number=existing_pr.number, - pr_url=existing_pr.html_url, - branch=branch_name, - base_branch=base_branch_name, - ) - return existing_pr.html_url, existing_pr.number, True - - return await self._create_pull_request( - repo=repo, - obj=obj, - options=options, - branch_name=branch_name, - base_branch_name=base_branch_name, - ) - - async def _upsert_pull_request_safe( - self, - *, - repo: Any, - url: GitUrl, - obj: PushObject[RemoteWorkflowDefinition], - options: PushOptions, - branch_name: str, - base_branch_name: str, - ) -> tuple[str | None, int | None, bool]: - try: - return await self._upsert_pull_request( - repo=repo, - url=url, - obj=obj, - options=options, - branch_name=branch_name, - base_branch_name=base_branch_name, - ) - except GithubException as e: - logger.error( - "Failed to create or reuse pull request via GitHub API", - status=e.status, - data=e.data, - branch=branch_name, - base_branch=base_branch_name, - repo=f"{url.org}/{url.repo}", - ) - return None, None, False - - async def _create_pull_request( - self, - *, - repo: Any, - obj: PushObject[RemoteWorkflowDefinition], - options: PushOptions, - branch_name: str, - base_branch_name: str, - ) -> tuple[str, int, bool]: - ws_svc = WorkspaceService(session=self.session, role=self.role) - workspace = await ws_svc.get_workspace(self.workspace_id) - if not workspace: - raise TracecatNotFoundError("Workspace not found") - - try: - title = obj.data.definition.title - description = obj.data.definition.description - except ValueError: - title = "" - description = "" - - current_user = None - if self.role.user_id is not None: - try: - current_user = await self.session.get(User, self.role.user_id) - except Exception: - current_user = None - - published_by = current_user.email if current_user else "" - - pr = await asyncio.to_thread( - repo.create_pull, - title=options.message, - body=( - f"Automated workflow sync from Tracecat\n\n" - f"**Workspace:** {workspace.name}\n" - f"**Published by:** {published_by}\n" - f"**Workflow Title:** {title}\n" - f"**Workflow Description:** {description}" - ), - head=branch_name, - base=base_branch_name, - ) - - logger.info( - "Created PR via GitHub API", - pr_number=pr.number, - pr_url=pr.html_url, - branch=branch_name, - base_branch=base_branch_name, - ) - return pr.html_url, pr.number, False - - async def list_commits( - self, - *, - url: GitUrl, - branch: str = "main", - limit: int = 10, - ) -> list[GitCommitInfo]: - """List commits from a Git repository using GitHub App API. - - Args: - url: Git repository URL - branch: Branch name to fetch commits from - limit: Maximum number of commits to return - - Returns: - List of GitCommitInfo objects with commit details - - Raises: - GitHubAppError: If GitHub authentication or API errors occur - """ - try: - # Get authenticated GitHub client - gh_svc = GitHubAppService(session=self.session, role=self.role) - gh = await gh_svc.get_github_client_for_repo(url) - - try: - # Get repository object - repo = await asyncio.to_thread(gh.get_repo, f"{url.org}/{url.repo}") - - # Fetch commits using PyGithub - commits_paginated = await asyncio.to_thread( - repo.get_commits, sha=branch - ) - - # Get all tags to build SHA-to-tags mapping - tags_paginated = await asyncio.to_thread(repo.get_tags) - sha_to_tags: dict[str, list[str]] = {} - - # Build mapping of commit SHA to tag names in thread to avoid blocking - def build_tag_mapping(): - result_map = {} - for tag in tags_paginated: - tag_sha = tag.commit.sha - if tag_sha not in result_map: - result_map[tag_sha] = [] - result_map[tag_sha].append(tag.name) - return result_map - - sha_to_tags = await asyncio.to_thread(build_tag_mapping) - - # Convert to GitCommitInfo objects - commits = [] - count = 0 - for commit in commits_paginated: - if count >= limit: - break - - # Get tags for this commit SHA, default to empty list - tags = sha_to_tags.get(commit.sha, []) - - commits.append( - GitCommitInfo( - sha=commit.sha, - message=commit.commit.message, - author=commit.commit.author.name or "Unknown", - author_email=commit.commit.author.email or "", - date=commit.commit.author.date.isoformat(), - tags=tags, - ) - ) - count += 1 - - return commits - - finally: - gh.close() - - except GithubException as e: - logger.error( - "GitHub API error during commit listing", - status=e.status, - data=e.data, - repo=f"{url.org}/{url.repo}", - branch=branch, - ) - raise GitHubAppError(f"GitHub API error: {e.status} - {e.data}") from e - - async def list_branches( - self, - *, - url: GitUrl, - limit: int = 100, - ) -> list[GitBranchInfo]: - """List branches from a Git repository using GitHub App API.""" - try: - gh_svc = GitHubAppService(session=self.session, role=self.role) - gh = await gh_svc.get_github_client_for_repo(url) - - try: - repo = await asyncio.to_thread(gh.get_repo, f"{url.org}/{url.repo}") - branches_paginated = await asyncio.to_thread(repo.get_branches) - - branches: list[GitBranchInfo] = [] - count = 0 - for branch_obj in branches_paginated: - if count >= limit: - break - branches.append( - GitBranchInfo( - name=branch_obj.name, - is_default=branch_obj.name == repo.default_branch, - ) - ) - count += 1 - - return branches - finally: - gh.close() - except GithubException as e: - logger.error( - "GitHub API error during branch listing", - status=e.status, - data=e.data, - repo=f"{url.org}/{url.repo}", - ) - raise GitHubAppError(f"GitHub API error: {e.status} - {e.data}") from e diff --git a/tracecat/workspace_sync/__init__.py b/tracecat/workspace_sync/__init__.py new file mode 100644 index 0000000000..0f0722a714 --- /dev/null +++ b/tracecat/workspace_sync/__init__.py @@ -0,0 +1 @@ +"""Workspace-level Git sync projection and reconciliation.""" diff --git a/tracecat/workspace_sync/enums.py b/tracecat/workspace_sync/enums.py new file mode 100644 index 0000000000..b07f13ab5c --- /dev/null +++ b/tracecat/workspace_sync/enums.py @@ -0,0 +1,65 @@ +"""Workspace sync enum values.""" + +from __future__ import annotations + +from enum import StrEnum + + +class SyncProvider(StrEnum): + GIT = "git" + + +class SyncResourceType(StrEnum): + WORKFLOW = "workflow" + + +class SyncStateStatus(StrEnum): + NEVER_SYNCED = "never_synced" + CLEAN = "clean" + LOCAL_DIRTY = "local_dirty" + REMOTE_AHEAD = "remote_ahead" + DIVERGED = "diverged" + CONFLICTED = "conflicted" + ERROR = "error" + + +class SyncDirection(StrEnum): + PULL = "pull" + PUSH = "push" + + +class ResourceSyncStatus(StrEnum): + UNTRACKED = "untracked" + SYNCED = "synced" + LOCAL_DIRTY = "local_dirty" + REMOTE_DIRTY = "remote_dirty" + CONFLICTED = "conflicted" + ERROR = "error" + + +class SyncOperation(StrEnum): + CREATE = "create" + UPDATE = "update" + DELETE = "delete" + ARCHIVE = "archive" + DISABLE = "disable" + + +class ChangeSetStatus(StrEnum): + OPEN = "open" + VALIDATED = "validated" + EXPORTED = "exported" + FAILED = "failed" + + +class ValidationStatus(StrEnum): + PENDING = "pending" + VALID = "valid" + INVALID = "invalid" + + +class MaterializationStatus(StrEnum): + PENDING = "pending" + COMMITTED = "committed" + NO_OP = "no_op" + FAILED = "failed" diff --git a/tracecat/workspace_sync/git.py b/tracecat/workspace_sync/git.py new file mode 100644 index 0000000000..8bc3e6390d --- /dev/null +++ b/tracecat/workspace_sync/git.py @@ -0,0 +1,307 @@ +"""GitHub-backed repo operations for workspace sync.""" + +from __future__ import annotations + +import asyncio +import base64 +from dataclasses import dataclass +from typing import Any + +from github.GithubException import GithubException +from github.InputGitTreeElement import InputGitTreeElement + +from tracecat.db.models import User +from tracecat.exceptions import TracecatNotFoundError +from tracecat.git.types import GitUrl +from tracecat.registry.repositories.schemas import GitBranchInfo, GitCommitInfo +from tracecat.service import BaseWorkspaceService +from tracecat.sync import CommitInfo, PushStatus +from tracecat.vcs.github.app import GitHubAppError, GitHubAppService +from tracecat.workspaces.service import WorkspaceService + + +@dataclass(frozen=True) +class GitTreeSnapshot: + commit_sha: str + tree_sha: str | None + files: dict[str, str] + + +class WorkspaceGitHubSyncService(BaseWorkspaceService): + """GitHub App transport for workspace sync.""" + + service_name = "workspace_github_sync" + + async def read_files( + self, + *, + url: GitUrl, + ref: str, + ) -> GitTreeSnapshot: + gh_svc = GitHubAppService(session=self.session, role=self.role) + gh = await gh_svc.get_github_client_for_repo(url) + try: + repo = await asyncio.to_thread(gh.get_repo, f"{url.org}/{url.repo}") + commit = await asyncio.to_thread(repo.get_commit, ref) + tree = await asyncio.to_thread( + repo.get_git_tree, + sha=commit.sha, + recursive=True, + ) + files: dict[str, str] = {} + for item in tree.tree: + if item.type != "blob" or not item.path: + continue + content_file = await asyncio.to_thread( + repo.get_contents, + item.path, + ref=commit.sha, + ) + if isinstance(content_file, list): + continue + files[item.path] = base64.b64decode(content_file.content).decode( + "utf-8" + ) + return GitTreeSnapshot( + commit_sha=commit.sha, + tree_sha=getattr(commit.commit.tree, "sha", None), + files=files, + ) + except GithubException as e: + raise GitHubAppError(f"GitHub API error: {e.status} - {e.data}") from e + finally: + gh.close() + + async def write_files( + self, + *, + url: GitUrl, + files: dict[str, str], + message: str, + branch: str, + create_pr: bool, + pr_base_branch: str | None = None, + ) -> CommitInfo: + if not files: + raise ValueError("At least one file is required for workspace sync export") + gh_svc = GitHubAppService(session=self.session, role=self.role) + gh = await gh_svc.get_github_client_for_repo(url) + try: + repo = await asyncio.to_thread(gh.get_repo, f"{url.org}/{url.repo}") + base_branch_name = pr_base_branch or url.ref or repo.default_branch + base_branch = await asyncio.to_thread(repo.get_branch, base_branch_name) + + try: + target_branch = await asyncio.to_thread(repo.get_branch, branch) + except GithubException as e: + if e.status != 404: + raise + await asyncio.to_thread( + repo.create_git_ref, + ref=f"refs/heads/{branch}", + sha=base_branch.commit.sha, + ) + target_branch = await asyncio.to_thread(repo.get_branch, branch) + + changed_files: dict[str, str] = {} + for path, content in files.items(): + existing_content: str | None = None + try: + existing = await asyncio.to_thread( + repo.get_contents, + path, + ref=branch, + ) + if not isinstance(existing, list): + existing_content = base64.b64decode(existing.content).decode( + "utf-8" + ) + except GithubException as e: + if e.status != 404: + raise + if existing_content != content: + changed_files[path] = content + + pr_url: str | None = None + pr_number: int | None = None + pr_reused = False + if not changed_files: + if create_pr: + pr_url, pr_number, pr_reused = await self._upsert_pull_request( + repo=repo, + url=url, + title=message, + branch_name=branch, + base_branch_name=base_branch_name, + ) + return CommitInfo( + status=PushStatus.NO_OP, + sha=None, + ref=branch, + base_ref=base_branch_name, + pr_url=pr_url, + pr_number=pr_number, + pr_reused=pr_reused, + message="No changes detected; nothing to commit.", + ) + + target_commit = await asyncio.to_thread( + repo.get_git_commit, + target_branch.commit.sha, + ) + elements = [] + for path, content in sorted(changed_files.items()): + blob = await asyncio.to_thread(repo.create_git_blob, content, "utf-8") + elements.append( + InputGitTreeElement( + path=path, + mode="100644", + type="blob", + sha=blob.sha, + ) + ) + + tree = await asyncio.to_thread( + repo.create_git_tree, + elements, + base_tree=target_commit.tree, + ) + commit = await asyncio.to_thread( + repo.create_git_commit, + message, + tree, + [target_commit], + ) + ref = await asyncio.to_thread(repo.get_git_ref, f"heads/{branch}") + await asyncio.to_thread(ref.edit, sha=commit.sha) + + if create_pr: + pr_url, pr_number, pr_reused = await self._upsert_pull_request( + repo=repo, + url=url, + title=message, + branch_name=branch, + base_branch_name=base_branch_name, + ) + + return CommitInfo( + status=PushStatus.COMMITTED, + sha=commit.sha, + ref=branch, + base_ref=base_branch_name, + pr_url=pr_url, + pr_number=pr_number, + pr_reused=pr_reused, + message="Committed workspace sync changes.", + ) + except GithubException as e: + raise GitHubAppError(f"GitHub API error: {e.status} - {e.data}") from e + finally: + gh.close() + + async def list_commits( + self, + *, + url: GitUrl, + branch: str = "main", + limit: int = 10, + ) -> list[GitCommitInfo]: + gh_svc = GitHubAppService(session=self.session, role=self.role) + gh = await gh_svc.get_github_client_for_repo(url) + try: + repo = await asyncio.to_thread(gh.get_repo, f"{url.org}/{url.repo}") + commits_paginated = await asyncio.to_thread(repo.get_commits, sha=branch) + commits: list[GitCommitInfo] = [] + for index, commit in enumerate(commits_paginated): + if index >= limit: + break + commits.append( + GitCommitInfo( + sha=commit.sha, + message=commit.commit.message, + author=commit.commit.author.name or "Unknown", + author_email=commit.commit.author.email or "", + date=commit.commit.author.date.isoformat(), + tags=[], + ) + ) + return commits + except GithubException as e: + raise GitHubAppError(f"GitHub API error: {e.status} - {e.data}") from e + finally: + gh.close() + + async def list_branches( + self, + *, + url: GitUrl, + limit: int = 100, + ) -> list[GitBranchInfo]: + gh_svc = GitHubAppService(session=self.session, role=self.role) + gh = await gh_svc.get_github_client_for_repo(url) + try: + repo = await asyncio.to_thread(gh.get_repo, f"{url.org}/{url.repo}") + branches_paginated = await asyncio.to_thread(repo.get_branches) + branches: list[GitBranchInfo] = [] + for index, branch_obj in enumerate(branches_paginated): + if index >= limit: + break + branches.append( + GitBranchInfo( + name=branch_obj.name, + is_default=branch_obj.name == repo.default_branch, + ) + ) + return branches + except GithubException as e: + raise GitHubAppError(f"GitHub API error: {e.status} - {e.data}") from e + finally: + gh.close() + + async def _upsert_pull_request( + self, + *, + repo: Any, + url: GitUrl, + title: str, + branch_name: str, + base_branch_name: str, + ) -> tuple[str | None, int | None, bool]: + def _first_open_pull_request() -> Any | None: + pulls = repo.get_pulls( + state="open", + head=f"{url.org}:{branch_name}", + base=base_branch_name, + ) + return next(iter(pulls), None) + + existing_pr = await asyncio.to_thread(_first_open_pull_request) + if existing_pr is not None: + return existing_pr.html_url, existing_pr.number, True + + workspace = await WorkspaceService( + session=self.session, role=self.role + ).get_workspace(self.workspace_id) + if workspace is None: + raise TracecatNotFoundError("Workspace not found") + + current_user = None + if self.role.user_id is not None: + try: + current_user = await self.session.get(User, self.role.user_id) + except Exception: + current_user = None + + published_by = current_user.email if current_user else "" + pr = await asyncio.to_thread( + repo.create_pull, + title=title, + body=( + "Automated workspace sync from Tracecat\n\n" + f"**Workspace:** {workspace.name}\n" + f"**Published by:** {published_by}" + ), + head=branch_name, + base=base_branch_name, + ) + return pr.html_url, pr.number, False diff --git a/tracecat/workspace_sync/schemas.py b/tracecat/workspace_sync/schemas.py new file mode 100644 index 0000000000..aee7c01768 --- /dev/null +++ b/tracecat/workspace_sync/schemas.py @@ -0,0 +1,189 @@ +"""Schemas for canonical workspace Git sync specs and APIs.""" + +from __future__ import annotations + +import uuid +from typing import Any, Literal + +from pydantic import BaseModel, Field, field_validator, model_validator + +from tracecat.core.schemas import Schema +from tracecat.dsl.common import DSLInput +from tracecat.sync import CommitInfo, PullDiagnostic +from tracecat.workflow.store.schemas import ( + RemoteCaseTrigger, + RemoteWebhook, + RemoteWorkflowSchedule, + RemoteWorkflowTag, + WorkflowDslPublishResult, +) +from tracecat.workspace_sync.enums import SyncOperation, SyncStateStatus + +MANIFEST_FILENAME = "tracecat.json" +WORKFLOW_ROOT = "workflows" +WORKFLOW_DEFINITION_FILENAME = "definition.yml" + + +class WorkspaceManifestResources(BaseModel): + workflows: str = f"{WORKFLOW_ROOT}/" + + +class WorkspaceManifest(BaseModel): + version: Literal[1] = 1 + resources: WorkspaceManifestResources = Field( + default_factory=WorkspaceManifestResources + ) + + +class WorkflowResourceSpec(BaseModel): + """Canonical Git-owned desired state for a workflow resource.""" + + version: Literal[1] = 1 + type: Literal["workflow"] = "workflow" + id: str = Field(min_length=1) + alias: str | None = None + folder_path: str | None = None + tags: list[RemoteWorkflowTag] | None = None + schedules: list[RemoteWorkflowSchedule] | None = None + webhook: RemoteWebhook | None = None + case_trigger: RemoteCaseTrigger | None = None + definition: DSLInput + + @field_validator("id") + @classmethod + def validate_source_id(cls, value: str) -> str: + cleaned = value.strip().strip("/") + if not cleaned: + raise ValueError("workflow source id cannot be empty") + if "/" in cleaned or "\\" in cleaned: + raise ValueError("workflow source id must be a single path segment") + return cleaned + + +class WorkspaceSpec(BaseModel): + version: Literal[1] = 1 + workflows: dict[str, WorkflowResourceSpec] = Field(default_factory=dict) + + @model_validator(mode="after") + def validate_workflow_keys(self) -> WorkspaceSpec: + for source_id, spec in self.workflows.items(): + if source_id != spec.id: + raise ValueError( + f"Workflow map key {source_id!r} does not match spec id {spec.id!r}" + ) + return self + + +class ProjectedFile(BaseModel): + path: str + content: str + + +class WorkspaceProjection(BaseModel): + manifest: WorkspaceManifest + spec: WorkspaceSpec + files: dict[str, str] + spec_hash: str + + +class WorkspaceRemoteSnapshot(BaseModel): + commit_sha: str + tree_sha: str | None = None + files: dict[str, str] + spec: WorkspaceSpec + spec_hash: str + + +class ResourceRef(BaseModel): + resource_type: str + source_id: str + source_path: str | None = None + local_id: uuid.UUID | None = None + + +class WorkspaceSyncStatus(BaseModel): + status: SyncStateStatus + base_spec_hash: str | None + local_spec_hash: str + remote_spec_hash: str | None = None + base_commit_sha: str | None = None + remote_commit_sha: str | None = None + target_ref: str | None = None + pending_change_count: int = 0 + diagnostics: list[PullDiagnostic] = Field(default_factory=list) + + +class WorkspaceSyncPullPreview(BaseModel): + success: bool + commit_sha: str + workflows_found: int + diagnostics: list[PullDiagnostic] = Field(default_factory=list) + message: str + + +class WorkspaceSyncPullRequest(BaseModel): + commit_sha: str = Field(min_length=40, max_length=64) + dry_run: bool = False + force: bool = False + + +class ChangeSetCreate(BaseModel): + title: str = Field(min_length=1) + description: str | None = None + resources: list[ResourceRef] + + +class ChangeSetExport(BaseModel): + message: str + branch: str + create_pr: bool = False + pr_base_branch: str | None = None + + +class WorkspaceSyncPendingChange(BaseModel): + resource_type: str + source_id: str + source_path: str + local_id: uuid.UUID | None = None + operation: SyncOperation + title: str | None = None + alias: str | None = None + before_spec_hash: str | None = None + after_spec_hash: str | None = None + exportable: bool = True + + +class WorkspaceSyncPendingChanges(BaseModel): + base_spec_hash: str | None = None + local_spec_hash: str + changes: list[WorkspaceSyncPendingChange] = Field(default_factory=list) + + +class ChangeSetRead(Schema): + id: uuid.UUID + title: str + description: str | None = None + base_commit_sha: str | None = None + base_spec_hash: str | None = None + selected_resources: list[dict[str, Any]] + selected_paths: list[str] + validation_status: str + validation_result: dict[str, Any] + status: str + + +class WorkspaceSyncExportResult(BaseModel): + changeset_id: uuid.UUID + commit: CommitInfo + + def as_workflow_publish_result(self) -> WorkflowDslPublishResult: + return WorkflowDslPublishResult( + status=self.commit.status.value, + commit_sha=self.commit.sha, + branch=self.commit.ref, + base_branch=self.commit.base_ref, + pr_url=self.commit.pr_url, + pr_number=self.commit.pr_number, + pr_reused=self.commit.pr_reused, + message=self.commit.message, + ) diff --git a/tracecat/workspace_sync/serialization.py b/tracecat/workspace_sync/serialization.py new file mode 100644 index 0000000000..a6f906d112 --- /dev/null +++ b/tracecat/workspace_sync/serialization.py @@ -0,0 +1,50 @@ +"""Canonical serialization helpers for workspace sync specs.""" + +from __future__ import annotations + +from hashlib import sha256 +from typing import Any + +import orjson +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +CANONICAL_HASH_VERSION = 1 + + +def canonical_data(value: Any) -> Any: + """Convert a value into JSON-compatible data with omitted null model fields.""" + if isinstance(value, BaseModel): + return value.model_dump(mode="json", exclude_none=True) + return to_jsonable_python(value, fallback=str) + + +def canonical_hash_data(value: Any) -> Any: + """Convert a value into stable hash data, ignoring model defaults.""" + if isinstance(value, BaseModel): + return value.model_dump( + mode="json", + exclude_none=True, + exclude_defaults=True, + ) + return to_jsonable_python(value, fallback=str) + + +def canonical_json_bytes(value: Any, *, pretty: bool = False) -> bytes: + """Serialize JSON deterministically.""" + option = orjson.OPT_SORT_KEYS + if pretty: + option |= orjson.OPT_INDENT_2 + return orjson.dumps(canonical_data(value), option=option) + + +def canonical_json_text(value: Any, *, pretty: bool = True) -> str: + """Serialize JSON deterministically as UTF-8 text with a trailing newline.""" + return canonical_json_bytes(value, pretty=pretty).decode("utf-8") + "\n" + + +def stable_hash(value: Any) -> str: + """Return a versioned SHA-256 hash for canonical sync comparisons.""" + payload = orjson.dumps(canonical_hash_data(value), option=orjson.OPT_SORT_KEYS) + digest = sha256(payload).hexdigest() + return f"v{CANONICAL_HASH_VERSION}:{digest}" diff --git a/tracecat/workspace_sync/service.py b/tracecat/workspace_sync/service.py new file mode 100644 index 0000000000..75922e25cd --- /dev/null +++ b/tracecat/workspace_sync/service.py @@ -0,0 +1,1076 @@ +"""Workspace Git sync projection, reconciliation, and ChangeSet service.""" + +from __future__ import annotations + +import uuid +from collections.abc import Sequence +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import desc, select +from sqlalchemy.dialects.postgresql import insert + +from tracecat.db.models import ( + Workflow, + Workspace, + WorkspaceSyncChangeSet, + WorkspaceSyncChangeSetItem, + WorkspaceSyncMaterialization, + WorkspaceSyncResourceMapping, + WorkspaceSyncState, +) +from tracecat.dsl.common import DSLInput +from tracecat.exceptions import ( + TracecatNotFoundError, + TracecatSettingsError, + TracecatValidationError, +) +from tracecat.git.types import GitUrl +from tracecat.git.utils import parse_git_url +from tracecat.identifiers.workflow import WorkflowUUID +from tracecat.service import BaseWorkspaceService +from tracecat.sync import PullDiagnostic, PullOptions, PullResult, PushOptions +from tracecat.workflow.management.definitions import WorkflowDefinitionsService +from tracecat.workflow.management.management import WorkflowsManagementService +from tracecat.workflow.store.import_service import WorkflowImportService +from tracecat.workflow.store.schemas import WorkflowDslPublishResult +from tracecat.workspace_sync.enums import ( + ChangeSetStatus, + MaterializationStatus, + ResourceSyncStatus, + SyncDirection, + SyncOperation, + SyncProvider, + SyncResourceType, + SyncStateStatus, + ValidationStatus, +) +from tracecat.workspace_sync.git import WorkspaceGitHubSyncService +from tracecat.workspace_sync.schemas import ( + MANIFEST_FILENAME, + ChangeSetCreate, + ChangeSetExport, + ChangeSetRead, + ResourceRef, + WorkflowResourceSpec, + WorkspaceManifest, + WorkspaceProjection, + WorkspaceRemoteSnapshot, + WorkspaceSpec, + WorkspaceSyncExportResult, + WorkspaceSyncPendingChange, + WorkspaceSyncPendingChanges, + WorkspaceSyncStatus, +) +from tracecat.workspace_sync.serialization import ( + canonical_json_text, + stable_hash, +) +from tracecat.workspace_sync.workflow import ( + default_workflow_source_id, + parse_workflow_spec, + serialize_workflow_spec, + workflow_source_path, + workflow_spec_from_orm, + workflow_spec_to_remote, +) +from tracecat.workspaces.service import WorkspaceService + + +class WorkspaceGitSyncService(BaseWorkspaceService): + """Workspace-level Git sync service.""" + + service_name = "workspace_git_sync" + + async def project_workspace( + self, + *, + workflow_ids: Sequence[WorkflowUUID] | None = None, + create_missing_mappings: bool = True, + ) -> WorkspaceProjection: + workflows = await self._list_projectable_workflows(workflow_ids=workflow_ids) + defn_service = WorkflowDefinitionsService(session=self.session, role=self.role) + mgmt_service = WorkflowsManagementService(session=self.session, role=self.role) + specs: dict[str, WorkflowResourceSpec] = {} + + for workflow in workflows: + await self.session.refresh( + workflow, + ["tags", "folder", "schedules", "webhook", "case_trigger"], + ) + dsl = await self._get_workflow_dsl( + workflow, + defn_service=defn_service, + mgmt_service=mgmt_service, + ) + preferred_source_id = default_workflow_source_id( + alias=workflow.alias, + title=dsl.title, + ) + mapping = await self._ensure_resource_mapping( + resource_type=SyncResourceType.WORKFLOW.value, + local_id=WorkflowUUID.new(workflow.id), + preferred_source_id=preferred_source_id, + source_path=workflow_source_path(preferred_source_id), + create=create_missing_mappings, + reserved_source_ids=set(specs), + ) + source_id = mapping.source_id if mapping else preferred_source_id + spec = workflow_spec_from_orm(workflow, dsl=dsl, source_id=source_id) + specs[source_id] = spec + + manifest = WorkspaceManifest() + spec = WorkspaceSpec(workflows=dict(sorted(specs.items()))) + files = self._files_from_spec(manifest=manifest, spec=spec) + return WorkspaceProjection( + manifest=manifest, + spec=spec, + files=files, + spec_hash=stable_hash(spec), + ) + + async def parse_files( + self, + files: dict[str, str], + *, + commit_sha: str = "", + tree_sha: str | None = None, + ) -> tuple[WorkspaceRemoteSnapshot, list[PullDiagnostic]]: + diagnostics: list[PullDiagnostic] = [] + manifest = WorkspaceManifest() + if manifest_content := files.get(MANIFEST_FILENAME): + try: + manifest = WorkspaceManifest.model_validate_json(manifest_content) + except Exception as e: + diagnostics.append( + PullDiagnostic( + workflow_path=MANIFEST_FILENAME, + workflow_title=None, + error_type="parse", + message=f"Invalid workspace manifest: {str(e)}", + details={"error": str(e)}, + ) + ) + return ( + WorkspaceRemoteSnapshot( + commit_sha=commit_sha, + tree_sha=tree_sha, + files=files, + spec=WorkspaceSpec(), + spec_hash=stable_hash(WorkspaceSpec()), + ), + diagnostics, + ) + + workflow_root = manifest.resources.workflows.strip("/") + workflows: dict[str, WorkflowResourceSpec] = {} + for path, content in sorted(files.items()): + if not path.startswith(f"{workflow_root}/"): + continue + spec, diagnostic = parse_workflow_spec(path, content) + if diagnostic is not None: + diagnostics.append(diagnostic) + continue + if spec is not None: + workflows[spec.id] = spec + + spec = WorkspaceSpec(workflows=dict(sorted(workflows.items()))) + return ( + WorkspaceRemoteSnapshot( + commit_sha=commit_sha, + tree_sha=tree_sha, + files=files, + spec=spec, + spec_hash=stable_hash(spec), + ), + diagnostics, + ) + + async def pull( + self, + *, + url: GitUrl, + options: PullOptions, + ) -> PullResult: + if not options.commit_sha: + return PullResult( + success=False, + commit_sha="", + workflows_found=0, + workflows_imported=0, + diagnostics=[ + PullDiagnostic( + workflow_path="", + workflow_title=None, + error_type="validation", + message="commit_sha is required", + details={}, + ) + ], + message="commit_sha is required", + ) + + git_svc = WorkspaceGitHubSyncService(session=self.session, role=self.role) + remote_tree = await git_svc.read_files(url=url, ref=options.commit_sha) + snapshot, diagnostics = await self.parse_files( + remote_tree.files, + commit_sha=remote_tree.commit_sha, + tree_sha=remote_tree.tree_sha, + ) + if diagnostics: + return PullResult( + success=False, + commit_sha=snapshot.commit_sha, + workflows_found=len(snapshot.spec.workflows), + workflows_imported=0, + diagnostics=diagnostics, + message=f"Failed to parse {len(diagnostics)} workflow definition(s)", + ) + + if options.dry_run: + return PullResult( + success=True, + commit_sha=snapshot.commit_sha, + workflows_found=len(snapshot.spec.workflows), + workflows_imported=0, + diagnostics=[], + message="Dry run completed - workspace spec validated but not applied", + ) + + local_projection = await self.project_workspace(create_missing_mappings=True) + state = await self._get_or_create_state(url=url) + if state.base_spec_hash and local_projection.spec_hash != state.base_spec_hash: + return PullResult( + success=False, + commit_sha=snapshot.commit_sha, + workflows_found=len(snapshot.spec.workflows), + workflows_imported=0, + diagnostics=[ + PullDiagnostic( + workflow_path="", + workflow_title=None, + error_type="conflict", + message=( + "Local syncable workspace state changed since the last " + "synced base. Export or discard local changes before pulling." + ), + details={ + "base_spec_hash": state.base_spec_hash, + "local_spec_hash": local_projection.spec_hash, + }, + ) + ], + message="Pull blocked by local workspace drift", + ) + + result = await self._reconcile_workflow_specs( + spec=snapshot.spec, + commit_sha=snapshot.commit_sha, + ) + if result.success: + await self._record_successful_pull( + state=state, + snapshot=snapshot, + url=url, + ) + return result + + async def get_status(self) -> WorkspaceSyncStatus: + """Return the workspace/repo three-way sync status for the configured repo.""" + url = await self._workspace_git_url() + state = await self._get_state(url=url) or self._unsaved_state(url=url) + local_projection = await self.project_workspace(create_missing_mappings=False) + + remote_snapshot: WorkspaceRemoteSnapshot | None = None + diagnostics: list[PullDiagnostic] = [] + remote_ref = state.target_ref or url.ref or "main" + try: + remote_snapshot, diagnostics = await self._read_remote_snapshot( + url=url, + ref=remote_ref, + ) + except Exception as e: + diagnostics.append( + PullDiagnostic( + workflow_path="", + workflow_title=None, + error_type="github", + message=f"Failed to read remote workspace spec: {str(e)}", + details={"error": str(e), "ref": remote_ref}, + ) + ) + + pending = await self._pending_changes_from_projection( + projection=local_projection, + state=state, + ) + remote_changed_source_ids = await self._remote_changed_source_ids( + remote_snapshot + ) + remote_spec_hash = remote_snapshot.spec_hash if remote_snapshot else None + status = self._classify_status( + base_spec_hash=state.base_spec_hash, + local_spec_hash=local_projection.spec_hash, + remote_spec_hash=remote_spec_hash, + local_changed_source_ids={change.source_id for change in pending.changes}, + remote_changed_source_ids=remote_changed_source_ids, + remote_diagnostics=diagnostics, + ) + + return WorkspaceSyncStatus( + status=status, + base_spec_hash=state.base_spec_hash, + local_spec_hash=local_projection.spec_hash, + remote_spec_hash=remote_spec_hash, + base_commit_sha=state.base_commit_sha, + remote_commit_sha=remote_snapshot.commit_sha if remote_snapshot else None, + target_ref=remote_ref, + pending_change_count=len(pending.changes), + diagnostics=diagnostics, + ) + + async def list_pending_changes(self) -> WorkspaceSyncPendingChanges: + """List local syncable changes relative to the last synced base.""" + url = await self._workspace_git_url() + state = await self._get_state(url=url) or self._unsaved_state(url=url) + projection = await self.project_workspace(create_missing_mappings=False) + return await self._pending_changes_from_projection( + projection=projection, + state=state, + ) + + async def create_changeset(self, params: ChangeSetCreate) -> ChangeSetRead: + """Create a reviewable ChangeSet from selected pending resources.""" + if not params.resources: + raise TracecatValidationError("At least one resource is required") + + projection = await self.project_workspace(create_missing_mappings=True) + specs = self._select_workflow_specs( + projection=projection, + resources=params.resources, + ) + selected_files = self._files_for_workflow_specs(specs) + changeset = await self._create_changeset_for_specs( + title=params.title, + description=params.description, + specs=specs, + selected_files=selected_files, + ) + await self.session.commit() + return self._changeset_to_read(changeset) + + async def list_changesets(self, *, limit: int = 50) -> list[ChangeSetRead]: + stmt = ( + select(WorkspaceSyncChangeSet) + .where( + WorkspaceSyncChangeSet.workspace_id == self.workspace_id, + WorkspaceSyncChangeSet.provider == SyncProvider.GIT.value, + ) + .order_by(desc(WorkspaceSyncChangeSet.created_at)) + .limit(limit) + ) + result = await self.session.execute(stmt) + return [self._changeset_to_read(row) for row in result.scalars().all()] + + async def get_changeset(self, changeset_id: uuid.UUID) -> ChangeSetRead: + changeset = await self._get_changeset(changeset_id) + return self._changeset_to_read(changeset) + + async def export_changeset( + self, + *, + changeset_id: uuid.UUID, + params: ChangeSetExport, + ) -> WorkspaceSyncExportResult: + """Materialize a ChangeSet into a Git branch and optional pull request.""" + changeset = await self._get_changeset(changeset_id) + selected_files = self._changeset_rendered_files(changeset) + url = await self._workspace_git_url() + + git_svc = WorkspaceGitHubSyncService(session=self.session, role=self.role) + commit = await git_svc.write_files( + url=url, + files=selected_files, + message=params.message, + branch=params.branch, + create_pr=params.create_pr, + pr_base_branch=params.pr_base_branch, + ) + + materialization = WorkspaceSyncMaterialization( + workspace_id=self.workspace_id, + changeset_id=changeset.id, + provider=SyncProvider.GIT.value, + branch=commit.ref, + base_ref=commit.base_ref, + pr_number=commit.pr_number, + pr_url=commit.pr_url, + commit_shas=[commit.sha] if commit.sha else [], + status=( + MaterializationStatus.COMMITTED.value + if commit.sha + else MaterializationStatus.NO_OP.value + ), + ) + changeset.status = ChangeSetStatus.EXPORTED.value + self.session.add_all([materialization, changeset]) + await self.session.commit() + return WorkspaceSyncExportResult(changeset_id=changeset.id, commit=commit) + + async def export_workflow( + self, + *, + workflow: Workflow, + dsl: DSLInput, + options: PushOptions, + ) -> WorkspaceSyncExportResult: + if not options.branch: + raise ValueError("branch is required for workspace sync export") + + url = await self._workspace_git_url() + await self.session.refresh( + workflow, + ["tags", "folder", "schedules", "webhook", "case_trigger"], + ) + preferred_source_id = default_workflow_source_id( + alias=workflow.alias, + title=dsl.title, + ) + mapping = await self._ensure_resource_mapping( + resource_type=SyncResourceType.WORKFLOW.value, + local_id=WorkflowUUID.new(workflow.id), + preferred_source_id=preferred_source_id, + source_path=workflow_source_path(preferred_source_id), + create=True, + reserved_source_ids=set(), + ) + if mapping is None: + raise RuntimeError("Expected workflow source mapping to be created") + + workflow_spec = workflow_spec_from_orm( + workflow, + dsl=dsl, + source_id=mapping.source_id, + ) + manifest = WorkspaceManifest() + selected_spec = WorkspaceSpec(workflows={workflow_spec.id: workflow_spec}) + selected_files = self._files_from_spec(manifest=manifest, spec=selected_spec) + changeset = await self._create_changeset_for_specs( + title=options.message, + description=None, + specs=[workflow_spec], + selected_files=selected_files, + ) + + git_svc = WorkspaceGitHubSyncService(session=self.session, role=self.role) + commit = await git_svc.write_files( + url=url, + files=selected_files, + message=options.message, + branch=options.branch, + create_pr=options.create_pr, + pr_base_branch=options.pr_base_branch, + ) + + materialization = WorkspaceSyncMaterialization( + workspace_id=self.workspace_id, + changeset_id=changeset.id, + provider=SyncProvider.GIT.value, + branch=commit.ref, + base_ref=commit.base_ref, + pr_number=commit.pr_number, + pr_url=commit.pr_url, + commit_shas=[commit.sha] if commit.sha else [], + status=( + MaterializationStatus.COMMITTED.value + if commit.sha + else MaterializationStatus.NO_OP.value + ), + ) + changeset.status = ChangeSetStatus.EXPORTED.value + mapping.source_path = workflow_source_path(workflow_spec.id) + mapping.last_synced_commit_sha = commit.sha + mapping.last_synced_spec_hash = stable_hash(workflow_spec) + mapping.sync_status = ResourceSyncStatus.SYNCED.value + workflow.git_sync_branch = commit.ref + self.session.add_all([materialization, changeset, mapping, workflow]) + await self.session.commit() + return WorkspaceSyncExportResult(changeset_id=changeset.id, commit=commit) + + async def export_workflow_publish_result( + self, + *, + workflow: Workflow, + dsl: DSLInput, + options: PushOptions, + ) -> WorkflowDslPublishResult: + result = await self.export_workflow(workflow=workflow, dsl=dsl, options=options) + return result.as_workflow_publish_result() + + async def list_commits( + self, + *, + url: GitUrl, + branch: str = "main", + limit: int = 10, + ) -> list[Any]: + return await WorkspaceGitHubSyncService( + session=self.session, + role=self.role, + ).list_commits(url=url, branch=branch, limit=limit) + + async def list_branches(self, *, url: GitUrl, limit: int = 100) -> list[Any]: + return await WorkspaceGitHubSyncService( + session=self.session, + role=self.role, + ).list_branches(url=url, limit=limit) + + async def _list_projectable_workflows( + self, + *, + workflow_ids: Sequence[WorkflowUUID] | None, + ) -> list[Workflow]: + stmt = ( + select(Workflow) + .where(Workflow.workspace_id == self.workspace_id) + .order_by(Workflow.created_at, Workflow.id) + ) + if workflow_ids: + stmt = stmt.where(Workflow.id.in_(list(workflow_ids))) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + + async def _get_workflow_dsl( + self, + workflow: Workflow, + *, + defn_service: WorkflowDefinitionsService, + mgmt_service: WorkflowsManagementService, + ) -> DSLInput: + definition = await defn_service.get_definition_by_workflow_id( + WorkflowUUID.new(workflow.id) + ) + if definition and definition.content: + return DSLInput.model_validate(definition.content) + return await mgmt_service.build_dsl_from_workflow(workflow) + + async def _ensure_resource_mapping( + self, + *, + resource_type: str, + local_id: uuid.UUID, + preferred_source_id: str, + source_path: str, + create: bool, + reserved_source_ids: set[str], + ) -> WorkspaceSyncResourceMapping | None: + stmt = select(WorkspaceSyncResourceMapping).where( + WorkspaceSyncResourceMapping.workspace_id == self.workspace_id, + WorkspaceSyncResourceMapping.provider == SyncProvider.GIT.value, + WorkspaceSyncResourceMapping.resource_type == resource_type, + WorkspaceSyncResourceMapping.local_id == local_id, + ) + if mapping := (await self.session.execute(stmt)).scalar_one_or_none(): + return mapping + if not create: + return None + + source_id = await self._unique_source_id( + resource_type=resource_type, + preferred_source_id=preferred_source_id, + reserved_source_ids=reserved_source_ids, + ) + mapping = WorkspaceSyncResourceMapping( + workspace_id=self.workspace_id, + provider=SyncProvider.GIT.value, + resource_type=resource_type, + source_id=source_id, + source_path=workflow_source_path(source_id) + if resource_type == SyncResourceType.WORKFLOW.value + else source_path, + local_id=local_id, + sync_status=ResourceSyncStatus.UNTRACKED.value, + ) + self.session.add(mapping) + await self.session.flush() + return mapping + + async def _unique_source_id( + self, + *, + resource_type: str, + preferred_source_id: str, + reserved_source_ids: set[str], + ) -> str: + base = preferred_source_id + counter = 2 + candidate = base + while candidate in reserved_source_ids or await self._source_id_exists( + resource_type=resource_type, + source_id=candidate, + ): + candidate = f"{base}-{counter}" + counter += 1 + return candidate + + async def _source_id_exists(self, *, resource_type: str, source_id: str) -> bool: + stmt = select(WorkspaceSyncResourceMapping.id).where( + WorkspaceSyncResourceMapping.workspace_id == self.workspace_id, + WorkspaceSyncResourceMapping.provider == SyncProvider.GIT.value, + WorkspaceSyncResourceMapping.resource_type == resource_type, + WorkspaceSyncResourceMapping.source_id == source_id, + ) + return (await self.session.execute(stmt)).scalar_one_or_none() is not None + + def _files_from_spec( + self, + *, + manifest: WorkspaceManifest, + spec: WorkspaceSpec, + ) -> dict[str, str]: + files = {MANIFEST_FILENAME: canonical_json_text(manifest)} + for source_id, workflow_spec in sorted(spec.workflows.items()): + files[workflow_source_path(source_id)] = serialize_workflow_spec( + workflow_spec + ) + return files + + def _files_for_workflow_specs( + self, + specs: list[WorkflowResourceSpec], + ) -> dict[str, str]: + manifest = WorkspaceManifest() + selected_spec = WorkspaceSpec( + workflows={ + spec.id: spec for spec in sorted(specs, key=lambda item: item.id) + } + ) + return self._files_from_spec(manifest=manifest, spec=selected_spec) + + async def _reconcile_workflow_specs( + self, + *, + spec: WorkspaceSpec, + commit_sha: str, + ) -> PullResult: + local_ids: dict[str, WorkflowUUID] = {} + remote_workflows = [] + for source_id, workflow_spec in sorted(spec.workflows.items()): + local_id = await self._resolve_local_workflow_id(source_id) + local_ids[source_id] = local_id + remote_workflows.append( + workflow_spec_to_remote(workflow_spec, local_workflow_id=local_id) + ) + + result = await WorkflowImportService( + session=self.session, + role=self.role, + ).import_workflows_atomic(remote_workflows, commit_sha=commit_sha) + if not result.success: + return result + + for source_id, workflow_spec in sorted(spec.workflows.items()): + await self._upsert_remote_mapping( + source_id=source_id, + source_path=workflow_source_path(source_id), + local_id=local_ids[source_id], + commit_sha=commit_sha, + spec_hash=stable_hash(workflow_spec), + ) + await self.session.commit() + return result + + async def _resolve_local_workflow_id(self, source_id: str) -> WorkflowUUID: + stmt = select(WorkspaceSyncResourceMapping).where( + WorkspaceSyncResourceMapping.workspace_id == self.workspace_id, + WorkspaceSyncResourceMapping.provider == SyncProvider.GIT.value, + WorkspaceSyncResourceMapping.resource_type + == SyncResourceType.WORKFLOW.value, + WorkspaceSyncResourceMapping.source_id == source_id, + ) + if mapping := (await self.session.execute(stmt)).scalar_one_or_none(): + return WorkflowUUID.new(mapping.local_id) + + try: + legacy_id = WorkflowUUID.new(source_id) + except ValueError: + return WorkflowUUID.new_uuid4() + + workflow = await self.session.scalar( + select(Workflow).where( + Workflow.workspace_id == self.workspace_id, + Workflow.id == legacy_id, + ) + ) + return WorkflowUUID.new(workflow.id) if workflow else WorkflowUUID.new_uuid4() + + async def _upsert_remote_mapping( + self, + *, + source_id: str, + source_path: str, + local_id: WorkflowUUID, + commit_sha: str, + spec_hash: str, + ) -> WorkspaceSyncResourceMapping: + stmt = select(WorkspaceSyncResourceMapping).where( + WorkspaceSyncResourceMapping.workspace_id == self.workspace_id, + WorkspaceSyncResourceMapping.provider == SyncProvider.GIT.value, + WorkspaceSyncResourceMapping.resource_type + == SyncResourceType.WORKFLOW.value, + WorkspaceSyncResourceMapping.source_id == source_id, + ) + mapping = (await self.session.execute(stmt)).scalar_one_or_none() + if mapping is None: + mapping = WorkspaceSyncResourceMapping( + workspace_id=self.workspace_id, + provider=SyncProvider.GIT.value, + resource_type=SyncResourceType.WORKFLOW.value, + source_id=source_id, + local_id=local_id, + ) + mapping.source_path = source_path + mapping.last_synced_commit_sha = commit_sha + mapping.last_synced_spec_hash = spec_hash + mapping.sync_status = ResourceSyncStatus.SYNCED.value + self.session.add(mapping) + return mapping + + async def _record_successful_pull( + self, + *, + state: WorkspaceSyncState, + snapshot: WorkspaceRemoteSnapshot, + url: GitUrl, + ) -> None: + state.repo_url = url.to_url() + state.target_ref = url.ref or state.target_ref + state.base_commit_sha = snapshot.commit_sha + state.base_tree_sha = snapshot.tree_sha + state.base_spec_hash = snapshot.spec_hash + state.last_remote_commit_sha = snapshot.commit_sha + state.last_remote_tree_sha = snapshot.tree_sha + state.status = SyncStateStatus.CLEAN.value + state.last_direction = SyncDirection.PULL.value + state.last_synced_at = datetime.now(UTC) + state.last_error = None + self.session.add(state) + await self.session.commit() + + async def _read_remote_snapshot( + self, + *, + url: GitUrl, + ref: str, + ) -> tuple[WorkspaceRemoteSnapshot, list[PullDiagnostic]]: + git_svc = WorkspaceGitHubSyncService(session=self.session, role=self.role) + remote_tree = await git_svc.read_files(url=url, ref=ref) + return await self.parse_files( + remote_tree.files, + commit_sha=remote_tree.commit_sha, + tree_sha=remote_tree.tree_sha, + ) + + async def _pending_changes_from_projection( + self, + *, + projection: WorkspaceProjection, + state: WorkspaceSyncState, + ) -> WorkspaceSyncPendingChanges: + mappings = await self._workflow_mappings() + mappings_by_source_id = {mapping.source_id: mapping for mapping in mappings} + changes: list[WorkspaceSyncPendingChange] = [] + + for source_id, spec in sorted(projection.spec.workflows.items()): + mapping = mappings_by_source_id.get(source_id) + before_hash = mapping.last_synced_spec_hash if mapping else None + after_hash = stable_hash(spec) + if before_hash == after_hash: + continue + operation = ( + SyncOperation.CREATE if before_hash is None else SyncOperation.UPDATE + ) + changes.append( + WorkspaceSyncPendingChange( + resource_type=SyncResourceType.WORKFLOW.value, + source_id=source_id, + source_path=workflow_source_path(source_id), + local_id=mapping.local_id if mapping else None, + operation=operation, + title=spec.definition.title, + alias=spec.alias, + before_spec_hash=before_hash, + after_spec_hash=after_hash, + exportable=True, + ) + ) + + return WorkspaceSyncPendingChanges( + base_spec_hash=state.base_spec_hash, + local_spec_hash=projection.spec_hash, + changes=changes, + ) + + async def _remote_changed_source_ids( + self, + remote_snapshot: WorkspaceRemoteSnapshot | None, + ) -> set[str]: + if remote_snapshot is None: + return set() + + mappings = await self._workflow_mappings() + mappings_by_source_id = {mapping.source_id: mapping for mapping in mappings} + remote_source_ids = set(remote_snapshot.spec.workflows) + changed: set[str] = set() + + for source_id, spec in remote_snapshot.spec.workflows.items(): + mapping = mappings_by_source_id.get(source_id) + remote_hash = stable_hash(spec) + if mapping is None or mapping.last_synced_spec_hash != remote_hash: + changed.add(source_id) + + for mapping in mappings: + if ( + mapping.last_synced_spec_hash + and mapping.source_id not in remote_source_ids + ): + changed.add(mapping.source_id) + + return changed + + def _classify_status( + self, + *, + base_spec_hash: str | None, + local_spec_hash: str, + remote_spec_hash: str | None, + local_changed_source_ids: set[str], + remote_changed_source_ids: set[str], + remote_diagnostics: list[PullDiagnostic], + ) -> SyncStateStatus: + if remote_diagnostics: + return SyncStateStatus.ERROR + if base_spec_hash is None or remote_spec_hash is None: + return SyncStateStatus.NEVER_SYNCED + + local_matches_base = local_spec_hash == base_spec_hash + remote_matches_base = remote_spec_hash == base_spec_hash + if local_matches_base and remote_matches_base: + return SyncStateStatus.CLEAN + if not local_matches_base and remote_matches_base: + return SyncStateStatus.LOCAL_DIRTY + if local_matches_base and not remote_matches_base: + return SyncStateStatus.REMOTE_AHEAD + if local_changed_source_ids & remote_changed_source_ids: + return SyncStateStatus.CONFLICTED + return SyncStateStatus.DIVERGED + + async def _workflow_mappings(self) -> list[WorkspaceSyncResourceMapping]: + stmt = select(WorkspaceSyncResourceMapping).where( + WorkspaceSyncResourceMapping.workspace_id == self.workspace_id, + WorkspaceSyncResourceMapping.provider == SyncProvider.GIT.value, + WorkspaceSyncResourceMapping.resource_type + == SyncResourceType.WORKFLOW.value, + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + + def _select_workflow_specs( + self, + *, + projection: WorkspaceProjection, + resources: list[ResourceRef], + ) -> list[WorkflowResourceSpec]: + specs: list[WorkflowResourceSpec] = [] + seen_source_ids: set[str] = set() + for resource in resources: + if resource.resource_type != SyncResourceType.WORKFLOW.value: + raise TracecatValidationError( + f"Unsupported sync resource type: {resource.resource_type}" + ) + if resource.source_id in seen_source_ids: + continue + spec = projection.spec.workflows.get(resource.source_id) + if spec is None: + raise TracecatValidationError( + f"Workflow source id not found in local projection: {resource.source_id}" + ) + specs.append(spec) + seen_source_ids.add(resource.source_id) + return specs + + async def _create_changeset_for_specs( + self, + *, + title: str, + description: str | None, + specs: list[WorkflowResourceSpec], + selected_files: dict[str, str], + ) -> WorkspaceSyncChangeSet: + state = await self._get_or_create_state(url=await self._workspace_git_url()) + mappings_by_source_id = { + mapping.source_id: mapping for mapping in await self._workflow_mappings() + } + selected_resources = [ + { + "resource_type": SyncResourceType.WORKFLOW.value, + "source_id": spec.id, + "source_path": workflow_source_path(spec.id), + "local_id": str(mappings_by_source_id[spec.id].local_id) + if spec.id in mappings_by_source_id + else None, + } + for spec in specs + ] + changeset = WorkspaceSyncChangeSet( + workspace_id=self.workspace_id, + provider=SyncProvider.GIT.value, + title=title, + description=description, + base_commit_sha=state.base_commit_sha, + base_spec_hash=state.base_spec_hash, + selected_resources=selected_resources, + selected_paths=sorted(selected_files), + rendered_files=dict(sorted(selected_files.items())), + validation_status=ValidationStatus.VALID.value, + validation_result={}, + status=ChangeSetStatus.VALIDATED.value, + created_by=self.role.user_id, + ) + self.session.add(changeset) + await self.session.flush() + for spec in specs: + mapping = mappings_by_source_id.get(spec.id) + item = WorkspaceSyncChangeSetItem( + workspace_id=self.workspace_id, + changeset_id=changeset.id, + resource_type=SyncResourceType.WORKFLOW.value, + source_id=spec.id, + source_path=workflow_source_path(spec.id), + local_id=mapping.local_id if mapping else None, + operation=( + SyncOperation.CREATE.value + if mapping is None or mapping.last_synced_spec_hash is None + else SyncOperation.UPDATE.value + ), + spec_hash=stable_hash(spec), + dependencies=[], + ) + self.session.add(item) + await self.session.flush() + return changeset + + def _changeset_rendered_files( + self, + changeset: WorkspaceSyncChangeSet, + ) -> dict[str, str]: + rendered_files = changeset.rendered_files or {} + if not rendered_files: + raise TracecatValidationError( + "Workspace sync ChangeSet has no frozen files. Recreate the ChangeSet before exporting." + ) + if any( + not isinstance(path, str) or not isinstance(content, str) + for path, content in rendered_files.items() + ): + raise TracecatValidationError( + "Workspace sync ChangeSet frozen files are invalid." + ) + return dict(sorted(rendered_files.items())) + + async def _get_changeset( + self, + changeset_id: uuid.UUID, + ) -> WorkspaceSyncChangeSet: + stmt = select(WorkspaceSyncChangeSet).where( + WorkspaceSyncChangeSet.workspace_id == self.workspace_id, + WorkspaceSyncChangeSet.provider == SyncProvider.GIT.value, + WorkspaceSyncChangeSet.id == changeset_id, + ) + changeset = (await self.session.execute(stmt)).scalar_one_or_none() + if changeset is None: + raise TracecatNotFoundError("Workspace sync ChangeSet not found") + return changeset + + def _changeset_to_read(self, changeset: WorkspaceSyncChangeSet) -> ChangeSetRead: + return ChangeSetRead( + id=changeset.id, + title=changeset.title, + description=changeset.description, + base_commit_sha=changeset.base_commit_sha, + base_spec_hash=changeset.base_spec_hash, + selected_resources=changeset.selected_resources, + selected_paths=changeset.selected_paths, + validation_status=changeset.validation_status, + validation_result=changeset.validation_result, + status=changeset.status, + ) + + async def _get_or_create_state(self, *, url: GitUrl) -> WorkspaceSyncState: + if state := await self._get_state(url=url): + return state + repo_url = url.to_url() + target_ref = url.ref or "main" + insert_stmt = ( + insert(WorkspaceSyncState) + .values( + workspace_id=self.workspace_id, + provider=SyncProvider.GIT.value, + repo_url=repo_url, + target_ref=target_ref, + status=SyncStateStatus.NEVER_SYNCED.value, + ) + .on_conflict_do_nothing( + constraint="uq_workspace_sync_state_workspace_provider_repo_ref" + ) + ) + await self.session.execute(insert_stmt) + state = await self._get_state(url=url) + if state is None: + raise RuntimeError("Workspace sync state was not created") + return state + + async def _get_state(self, *, url: GitUrl) -> WorkspaceSyncState | None: + repo_url = url.to_url() + target_ref = url.ref or "main" + stmt = select(WorkspaceSyncState).where( + WorkspaceSyncState.workspace_id == self.workspace_id, + WorkspaceSyncState.provider == SyncProvider.GIT.value, + WorkspaceSyncState.repo_url == repo_url, + WorkspaceSyncState.target_ref == target_ref, + ) + return (await self.session.execute(stmt)).scalar_one_or_none() + + def _unsaved_state(self, *, url: GitUrl) -> WorkspaceSyncState: + return WorkspaceSyncState( + workspace_id=self.workspace_id, + provider=SyncProvider.GIT.value, + repo_url=url.to_url(), + target_ref=url.ref or "main", + status=SyncStateStatus.NEVER_SYNCED.value, + ) + + async def _workspace_git_url(self) -> GitUrl: + workspace = await self._workspace() + repo_url = ( + workspace.settings.get("git_repo_url") if workspace.settings else None + ) + if not repo_url: + raise TracecatSettingsError( + "Git repository URL not configured for this workspace." + ) + try: + return parse_git_url(repo_url, allowed_domains={"github.com"}) + except ValueError as e: + raise TracecatSettingsError( + f"Invalid Git repository URL configured for this workspace: {e}" + ) from e + + async def _workspace(self) -> Workspace: + workspace = await WorkspaceService( + session=self.session, + role=self.role, + ).get_workspace(self.workspace_id) + if workspace is None: + raise TracecatNotFoundError("Workspace not found") + return workspace diff --git a/tracecat/workspace_sync/workflow.py b/tracecat/workspace_sync/workflow.py new file mode 100644 index 0000000000..556033760e --- /dev/null +++ b/tracecat/workspace_sync/workflow.py @@ -0,0 +1,203 @@ +"""Workflow resource adapter for workspace Git sync.""" + +from __future__ import annotations + +from typing import Any, cast + +import yaml +from pydantic import ValidationError +from slugify import slugify + +from tracecat.cases.enums import CaseEventType +from tracecat.db.models import Workflow +from tracecat.dsl.common import DSLInput +from tracecat.identifiers.workflow import WorkflowUUID +from tracecat.sync import PullDiagnostic +from tracecat.workflow.case_triggers.schemas import is_case_trigger_configured +from tracecat.workflow.store.schemas import ( + RemoteCaseTrigger, + RemoteWebhook, + RemoteWorkflowDefinition, + RemoteWorkflowSchedule, + RemoteWorkflowTag, + Status, +) +from tracecat.workspace_sync.schemas import ( + WORKFLOW_DEFINITION_FILENAME, + WORKFLOW_ROOT, + WorkflowResourceSpec, +) + + +def workflow_source_path(source_id: str) -> str: + return f"{WORKFLOW_ROOT}/{source_id}/{WORKFLOW_DEFINITION_FILENAME}" + + +def workflow_source_id_from_path(path: str) -> str | None: + parts = path.strip("/").split("/") + if len(parts) != 3: + return None + root, source_id, filename = parts + if root != WORKFLOW_ROOT or filename != WORKFLOW_DEFINITION_FILENAME: + return None + return source_id or None + + +def default_workflow_source_id(*, alias: str | None, title: str) -> str: + base = slugify(alias or title, separator="-") or "workflow" + return base[:96].strip("-") or "workflow" + + +def workflow_spec_from_orm( + workflow: Workflow, + *, + dsl: DSLInput, + source_id: str, +) -> WorkflowResourceSpec: + folder_path = workflow.folder.path if workflow.folder else None + webhook = workflow.webhook + + case_trigger = None + if workflow.case_trigger and is_case_trigger_configured( + status=workflow.case_trigger.status, + event_types=workflow.case_trigger.event_types, + tag_filters=workflow.case_trigger.tag_filters, + ): + case_trigger = RemoteCaseTrigger( + status=cast(Status, workflow.case_trigger.status), + event_types=[ + CaseEventType(event_type) + for event_type in workflow.case_trigger.event_types + ], + tag_filters=workflow.case_trigger.tag_filters, + ) + + return WorkflowResourceSpec( + id=source_id, + alias=workflow.alias, + folder_path=folder_path, + tags=[RemoteWorkflowTag(name=t.name) for t in workflow.tags] or None, + schedules=[ + RemoteWorkflowSchedule( + status=cast(Status, s.status), + cron=s.cron, + every=s.every, + offset=s.offset, + start_at=s.start_at, + end_at=s.end_at, + timeout=s.timeout, + ) + for s in (workflow.schedules or []) + ] + or None, + webhook=RemoteWebhook( + methods=webhook.methods, + status=cast(Status, webhook.status), + include_headers=webhook.include_headers, + ) + if webhook + else None, + case_trigger=case_trigger, + definition=dsl, + ) + + +def workflow_spec_to_remote( + spec: WorkflowResourceSpec, *, local_workflow_id: WorkflowUUID +) -> RemoteWorkflowDefinition: + return RemoteWorkflowDefinition( + id=local_workflow_id.short(), + alias=spec.alias, + folder_path=spec.folder_path, + tags=spec.tags, + schedules=spec.schedules, + webhook=spec.webhook, + case_trigger=spec.case_trigger, + definition=spec.definition, + ) + + +def workflow_spec_from_legacy( + remote: RemoteWorkflowDefinition, + *, + source_id: str | None = None, +) -> WorkflowResourceSpec: + return WorkflowResourceSpec( + id=source_id or remote.id, + alias=remote.alias, + folder_path=remote.folder_path, + tags=remote.tags, + schedules=remote.schedules, + webhook=remote.webhook, + case_trigger=remote.case_trigger, + definition=remote.definition, + ) + + +def serialize_workflow_spec(spec: WorkflowResourceSpec) -> str: + data = spec.model_dump(mode="json", exclude_none=True) + return yaml.safe_dump(data, sort_keys=False, allow_unicode=True) + + +def parse_workflow_spec( + path: str, content: str +) -> tuple[WorkflowResourceSpec | None, PullDiagnostic | None]: + source_id = workflow_source_id_from_path(path) + yaml_data: dict[str, Any] | None = None + try: + raw = yaml.safe_load(content) + if not isinstance(raw, dict) or not raw: + return None, PullDiagnostic( + workflow_path=path, + workflow_title=None, + error_type="parse", + message="Empty or invalid workflow YAML file", + details={}, + ) + yaml_data = raw + + if raw.get("type") == "workflow" and raw.get("version") == 1: + if "id" not in raw and source_id is not None: + raw = {**raw, "id": source_id} + spec = WorkflowResourceSpec.model_validate(raw) + if source_id is not None and spec.id != source_id: + return None, PullDiagnostic( + workflow_path=path, + workflow_title=spec.definition.title, + error_type="validation", + message="Workflow source id does not match its repository path", + details={"path_source_id": source_id, "spec_id": spec.id}, + ) + return spec, None + + legacy = RemoteWorkflowDefinition.model_validate(raw) + return workflow_spec_from_legacy(legacy, source_id=source_id), None + except yaml.YAMLError as e: + return None, PullDiagnostic( + workflow_path=path, + workflow_title=None, + error_type="parse", + message=f"YAML parsing error: {str(e)}", + details={"yaml_error": str(e)}, + ) + except ValidationError as e: + workflow_title = ( + yaml_data.get("definition", {}).get("title") + if isinstance(yaml_data, dict) + else None + ) + return None, PullDiagnostic( + workflow_path=path, + workflow_title=workflow_title, + error_type="validation", + message=f"Validation error: {str(e)}", + details={"validation_errors": e.errors()}, + ) + except Exception as e: + return None, PullDiagnostic( + workflow_path=path, + workflow_title=None, + error_type="parse", + message=f"Unexpected parsing error: {str(e)}", + details={"error": str(e)}, + )