diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 4eb47c964..ce9b01364 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,3 +2,4 @@ # https://docs.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners#codeowners-syntax * @temporalio/sdk +/contrib/workflowstreams/ @temporalio/sdk @temporalio/ai-sdk diff --git a/contrib/workflowstreams/README.md b/contrib/workflowstreams/README.md new file mode 100644 index 000000000..052af67b4 --- /dev/null +++ b/contrib/workflowstreams/README.md @@ -0,0 +1,152 @@ +# Workflow Streams + +A durable publish/subscribe log hosted inside a Temporal Workflow. + +External code (activities, starters, other workflows) publishes messages to +named topics via **signals**; subscribers long-poll for new items via +**updates**; a **query** exposes the current offset. The stream is backed by +Temporal's durable execution, giving ordered, durable, exactly-once delivery +with client-side batching, publisher dedup, continue-as-new survival, +truncation, and ~1 MB response paging. + +It is well suited to durable event streams whose cost scales with durable +batches rather than message count. Each poll round-trip costs ~100 ms of +latency, so it is not intended for ultra-low-latency streaming. + +## Workflow side + +Construct a `WorkflowStream` once at the start of your workflow. The constructor +registers the publish signal, poll update, and offset query handlers. + +```go +type MyInput struct { + ItemsProcessed int // your own workflow state + StreamState *workflowstreams.WorkflowStreamState +} + +func MyWorkflow(ctx workflow.Context, input MyInput) error { + stream, err := workflowstreams.NewWorkflowStream(ctx, input.StreamState) + if err != nil { + return err + } + + // Optionally publish from workflow code: + if err := stream.Topic("events").Publish("hello from the workflow"); err != nil { + return err + } + + // Run your workflow; the stream serves external publishers and subscribers + // for as long as the workflow is running. Block until your workflow's exit + // condition is met (here, a `done` flag set elsewhere, e.g. by a signal). + return workflow.Await(ctx, func() bool { return done }) +} +``` + +For workflows that use continue-as-new, the stream's log and offsets must be +carried across each boundary, since continue-as-new starts a fresh run with an +empty history. This is a round-trip with two halves: + +- **Capture** the state when rolling over. Instead of returning a plain + `workflow.NewContinueAsNewError`, return `stream.NewContinueAsNewError`. It + snapshots the current stream state and hands it to your callback, which builds + the argument list for the next run. The callback is where you assemble the + full input — carry forward your own workflow state alongside the captured + `state`: + + ```go + return stream.NewContinueAsNewError(ctx, MyWorkflow, func(state *workflowstreams.WorkflowStreamState) []any { + return []any{MyInput{ + ItemsProcessed: itemsProcessed, // your own state, carried across the boundary + StreamState: state, // the captured stream state + }} + }) + ``` + +- **Restore** it on the next run. That `MyInput` arrives as the next run's input, + and its `StreamState` field is the value already passed to `NewWorkflowStream` in the + example above. It is `nil` on a fresh start and non-nil after a roll-over, so + `NewWorkflowStream` rehydrates the log automatically. + +The `*workflowstreams.WorkflowStreamState` field is what gives the captured +stream state somewhere to live between runs; the other fields on `MyInput` are +your own and are threaded through the same way. + +## Publishing (client side) + +From an activity, use `NewClientFromActivity` to target the parent workflow: + +```go +func PublishActivity(ctx context.Context) error { + c, err := workflowstreams.NewClientFromActivity(ctx, workflowstreams.Options{}) + if err != nil { + return err + } + defer c.Close(ctx) // Flush the remaining buffer + + topic := c.Topic("events") + for i := range 100 { + topic.Publish(fmt.Sprintf("item %d", i), false) + } + return nil +} +``` + +From a starter or any code with a `client.Client`, use `NewClient` with an +explicit workflow ID: + +```go +c := workflowstreams.NewClient(temporalClient, workflowID, workflowstreams.Options{}) +defer c.Close(ctx) +c.Topic("events").Publish("from outside", true /* forceFlush */) +``` + +Items are buffered and flushed automatically every `BatchInterval` (default 2s), +when the buffer reaches `MaxBatchSize`, on `forceFlush`, on an explicit +`Flush`, or on `Close`. + +## Subscribing + +`Subscribe` returns a range-over-func iterator: + +```go +for item, err := range c.Subscribe(ctx, workflowstreams.SubscribeOptions{ + Topics: []string{"events"}, // empty/nil = all topics +}) { + if err != nil { + return err + } + var s string + if err := converter.GetDefaultDataConverter().FromPayload(item.Data, &s); err != nil { + return err + } + fmt.Printf("offset=%d topic=%s value=%s\n", item.Offset, item.Topic, s) +} +``` + +Breaking out of the loop or cancelling `ctx` stops the subscription and tears +down the poll loop. The iterator ends cleanly when the workflow reaches a +terminal state, automatically follows continue-as-new chains, and recovers from +truncation by restarting from the current base offset. + +Items yield the raw `*commonpb.Payload`; decode at the call site with your data +converter. Offsets are **global** (across all topics), not per-topic. + +## Options + +| Option | Default | Meaning | +| --- | --- | --- | +| `BatchInterval` | 2s | Automatic flush interval | +| `MaxBatchSize` | unset | Flush once the buffer reaches this size | +| `MaxRetryDuration` | 10m | Max time to retry a failed flush before `FlushTimeoutError`. Must be < the workflow's publisher TTL (15m) to preserve exactly-once delivery | +| `PayloadConverters` | default set | Per-item serialization. Payload conversion only — the client's codec chain runs once on the envelope, never per item | +| `SubscribeOptions.PollCooldown` | 100ms | Min interval between polls | + +## Cross-language protocol + +The handler names (`PublishSignalName`, `PollUpdateName`, `OffsetQueryName`), +the JSON envelope field names, and the per-item payload encoding (base64 of the +marshaled `temporal.api.common.v1.Payload`) match the Python and TypeScript +packages exactly, so a Go publisher or subscriber interoperates with a +Python/TypeScript workflow and vice versa. The data converter codec chain +(encryption, compression) runs once on the signal/update envelope — never per +item — so payloads are not double-encoded. diff --git a/contrib/workflowstreams/client.go b/contrib/workflowstreams/client.go new file mode 100644 index 000000000..a99999fdd --- /dev/null +++ b/contrib/workflowstreams/client.go @@ -0,0 +1,285 @@ +package workflowstreams + +import ( + "context" + "errors" + "iter" + "sync" + "time" + + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/converter" + "go.temporal.io/sdk/temporal" +) + +// Options configures a Client. +type Options struct { + // BatchInterval is the interval between automatic flushes. Default: 2s. + BatchInterval time.Duration + // MaxBatchSize triggers a flush once the buffer reaches this many items. + // Zero disables size-based flushing. + MaxBatchSize int + // MaxRetryDuration is the maximum time to retry a failed flush before + // returning a FlushTimeoutError. Must be less than the workflow's publisher + // TTL (default 15m) to preserve exactly-once delivery. Default: 10m. + MaxRetryDuration time.Duration + // PayloadConverters customize how published values are serialized into the + // per-item Payloads carried inside each batch. They are combined into a + // CompositeDataConverter in the order given (as with + // converter.NewCompositeDataConverter), so the last one should be a + // catch-all such as converter.NewJSONPayloadConverter. + // + // Only payload conversion happens here — never a payload codec + // (encryption, compression). The codec chain configured on the Temporal + // client runs once on the signal/update envelope that carries each batch, + // so encoding items here too would double-encode them; the + // []PayloadConverter type makes that mistake impossible. To decode + // subscribed items, use a converter built from the same PayloadConverters. + // + // Default: the converters from converter.GetDefaultDataConverter(). + PayloadConverters []converter.PayloadConverter +} + +// SubscribeOptions configures a subscription. +type SubscribeOptions struct { + // Topics filters the subscription. Empty or nil means all topics. + Topics []string + // FromOffset is the global offset to start from. Zero means the beginning. + FromOffset int64 + // PollCooldown is the minimum interval between polls when no more items are + // immediately ready. Default: 100ms. + PollCooldown time.Duration +} + +// Client publishes to and subscribes from a workflow stream from external code +// (activities, starters, other workflows). The publish path is owned by an +// internal publisher; the Client itself holds the target workflow and the read +// (subscribe/query) surface. +type Client struct { + c client.Client + workflowID string + followCAN bool + pub *publisher + + mu sync.Mutex + topicHandles map[string]*TopicHandle +} + +// NewClient creates a Client targeting workflowID through the given Temporal +// client. The returned Client follows continue-as-new chains in Subscribe. +func NewClient(c client.Client, workflowID string, opts Options) *Client { + // Build a codec-free converter for per-item serialization. A composite of + // PayloadConverters cannot apply a payload codec, so items are never + // double-encoded against the codec on the client's envelope. + var dc converter.DataConverter = converter.GetDefaultDataConverter() + if len(opts.PayloadConverters) > 0 { + dc = converter.NewCompositeDataConverter(opts.PayloadConverters...) + } + wsc := &Client{ + c: c, + workflowID: workflowID, + followCAN: true, + topicHandles: map[string]*TopicHandle{}, + } + wsc.pub = newPublisher(func(ctx context.Context, in PublishInput) error { + return c.SignalWorkflow(ctx, workflowID, "", PublishSignalName, in) + }, dc, opts) + return wsc +} + +// NewClientFromActivity creates a Client targeting the current activity's parent +// workflow, using the activity's Temporal client. It returns an error if the +// activity has no parent workflow (a standalone activity); in that case use +// NewClient with an explicit workflow ID. +func NewClientFromActivity(ctx context.Context, opts Options) (*Client, error) { + info := activity.GetInfo(ctx) + if info.WorkflowExecution.ID == "" { + return nil, errors.New("workflowstreams: NewClientFromActivity requires an activity scheduled by a workflow; " + + "from a standalone activity use NewClient with an explicit workflow ID") + } + return NewClient(activity.GetClient(ctx), info.WorkflowExecution.ID, opts), nil +} + +// Topic returns a handle for publishing to and subscribing from name. Repeated +// calls with the same name return the same handle. +func (c *Client) Topic(name string) *TopicHandle { + c.mu.Lock() + defer c.mu.Unlock() + if h, ok := c.topicHandles[name]; ok { + return h + } + h := &TopicHandle{name: name, client: c} + c.topicHandles[name] = h + return h +} + +// Flush sends buffered (and pending) items and waits for server confirmation. +// It returns once the items buffered at call time have been signaled to the +// workflow and acknowledged. Returns a FlushTimeoutError if a pending batch +// cannot be sent within MaxRetryDuration. +func (c *Client) Flush(ctx context.Context) error { return c.pub.flush(ctx) } + +// Close stops the background publisher and drains any remaining items. Call it +// (e.g. via defer) to guarantee a final flush. It surfaces any deferred +// FlushTimeoutError from a prior background flush failure. +func (c *Client) Close(ctx context.Context) error { return c.pub.close(ctx) } + +// GetOffset queries the current global offset of the stream. +func (c *Client) GetOffset(ctx context.Context) (int64, error) { + val, err := c.c.QueryWorkflow(ctx, c.workflowID, "", OffsetQueryName) + if err != nil { + return 0, err + } + var n int64 + if err := val.Get(&n); err != nil { + return 0, err + } + return n, nil +} + +// Subscribe returns a range-over-func iterator that long-polls for new items. +// Iterate with: +// +// for item, err := range c.Subscribe(ctx, opts) { +// if err != nil { ... } +// // use item +// } +// +// Breaking out of the loop or cancelling ctx stops the subscription and tears +// down the poll loop. Each yielded item carries the raw *commonpb.Payload in +// Data; decode it with your data converter. The iterator ends cleanly when the +// workflow reaches a terminal state, and automatically follows continue-as-new +// chains. +func (c *Client) Subscribe(ctx context.Context, opts SubscribeOptions) iter.Seq2[WorkflowStreamItem, error] { + return func(yield func(WorkflowStreamItem, error) bool) { + pollCooldown := opts.PollCooldown + if pollCooldown <= 0 { + pollCooldown = defaultPollCooldown + } + topics := opts.Topics + if topics == nil { + topics = []string{} + } + offset := opts.FromOffset + // polledRunID is the run the most recent poll's update was admitted to. + // We capture it before waiting for the update's outcome so that, if that + // run continues-as-new mid-poll (failing the outcome), we still know which + // run to inspect to tell a rollover apart from a terminal end. + var polledRunID string + + for { + if err := ctx.Err(); err != nil { + yield(WorkflowStreamItem{}, err) + return + } + + var result PollResult + // Wait only for ACCEPTED so UpdateWorkflow returns the handle (and its + // run id) as soon as the update is admitted; handle.Get then waits for + // the outcome. With WaitForStage Completed a mid-poll continue-as-new + // would fail UpdateWorkflow with a nil handle, losing the run id. + handle, err := c.c.UpdateWorkflow(ctx, client.UpdateWorkflowOptions{ + WorkflowID: c.workflowID, + UpdateName: PollUpdateName, + Args: []any{PollInput{Topics: topics, FromOffset: offset}}, + WaitForStage: client.WorkflowUpdateStageAccepted, + }) + if err == nil { + polledRunID = handle.RunID() + err = handle.Get(ctx, &result) + } + if err != nil { + var appErr *temporal.ApplicationError + if errors.As(err, &appErr) { + switch appErr.Type() { + case ErrTypeTruncatedOffset: + // Fell behind truncation; restart from the beginning of + // whatever still exists. + offset = 0 + continue + case ErrTypeStreamDraining: + // The workflow is detaching for continue-as-new. Back off + // and retry; the poll lands on the successor run once the + // rollover completes (or the chain/terminal checks below + // fire on a genuine end). + select { + case <-time.After(pollCooldown): + continue + case <-ctx.Done(): + yield(WorkflowStreamItem{}, ctx.Err()) + return + } + } + } + // The workflow may have continued-as-new or completed between + // polls. Follow the chain, exit cleanly on a terminal state, + // otherwise surface the error. + if followed := c.followContinueAsNew(ctx, polledRunID); followed { + continue + } + if c.isTerminal(ctx, polledRunID) { + return + } + yield(WorkflowStreamItem{}, err) + return + } + + for _, wi := range result.Items { + payload, derr := decodePayloadWire(wi.Data) + if derr != nil { + yield(WorkflowStreamItem{}, derr) + return + } + if !yield(WorkflowStreamItem{Topic: wi.Topic, Data: payload, Offset: wi.Offset}, nil) { + return + } + } + offset = result.NextOffset + + if !result.MoreReady { + select { + case <-time.After(pollCooldown): + case <-ctx.Done(): + yield(WorkflowStreamItem{}, ctx.Err()) + return + } + } + } + } +} + +// followContinueAsNew reports whether runID (the run we were polling) rolled +// over to a fresh run via continue-as-new. It describes that specific run: a +// rolled-over run is closed with status CONTINUED_AS_NEW, whereas the latest run +// would report RUNNING, so describing by run id is what makes the check fire. +// The successor run id is not needed — subsequent polls use an empty run id and +// so address the latest run automatically. A blank runID (no poll has been +// admitted yet) falls back to describing the latest run. +func (c *Client) followContinueAsNew(ctx context.Context, runID string) bool { + if !c.followCAN { + return false + } + desc, err := c.c.DescribeWorkflowExecution(ctx, c.workflowID, runID) + if err != nil { + return false + } + return desc.GetWorkflowExecutionInfo().GetStatus() == enumspb.WORKFLOW_EXECUTION_STATUS_CONTINUED_AS_NEW +} + +func (c *Client) isTerminal(ctx context.Context, runID string) bool { + desc, err := c.c.DescribeWorkflowExecution(ctx, c.workflowID, runID) + if err != nil { + return false + } + switch desc.GetWorkflowExecutionInfo().GetStatus() { + case enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED, + enumspb.WORKFLOW_EXECUTION_STATUS_FAILED, + enumspb.WORKFLOW_EXECUTION_STATUS_CANCELED, + enumspb.WORKFLOW_EXECUTION_STATUS_TERMINATED, + enumspb.WORKFLOW_EXECUTION_STATUS_TIMED_OUT: + return true + } + return false +} diff --git a/contrib/workflowstreams/client_test.go b/contrib/workflowstreams/client_test.go new file mode 100644 index 000000000..6f0889428 --- /dev/null +++ b/contrib/workflowstreams/client_test.go @@ -0,0 +1,187 @@ +package workflowstreams + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/converter" +) + +// fakeClient implements just the client.Client methods Client uses. Embedding +// the interface satisfies the rest; calling an unimplemented method panics, +// which is fine because the tests never exercise them. +type fakeClient struct { + client.Client + + mu sync.Mutex + signals []PublishInput + signalErr error // when set, SignalWorkflow returns this error +} + +func (f *fakeClient) SignalWorkflow(_ context.Context, _, _, signalName string, arg any) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.signalErr != nil { + return f.signalErr + } + if signalName == PublishSignalName { + f.signals = append(f.signals, arg.(PublishInput)) + } + return nil +} + +func (f *fakeClient) recorded() []PublishInput { + f.mu.Lock() + defer f.mu.Unlock() + return append([]PublishInput(nil), f.signals...) +} + +func decodeFirstItem(t *testing.T, in PublishInput, idx int) string { + t.Helper() + payload, err := decodePayloadWire(in.Items[idx].Data) + require.NoError(t, err) + var s string + require.NoError(t, converter.GetDefaultDataConverter().FromPayload(payload, &s)) + return s +} + +func TestFlushSendsBufferedItems(t *testing.T) { + fc := &fakeClient{} + c := NewClient(fc, "wf-1", Options{}) + c.Topic("events").Publish("a", false) + c.Topic("events").Publish("b", false) + + require.NoError(t, c.Flush(context.Background())) + + sigs := fc.recorded() + require.Len(t, sigs, 1) + require.Len(t, sigs[0].Items, 2) + require.EqualValues(t, 1, sigs[0].Sequence) + require.NotEmpty(t, sigs[0].PublisherID) + require.Equal(t, "a", decodeFirstItem(t, sigs[0], 0)) + require.Equal(t, "b", decodeFirstItem(t, sigs[0], 1)) + + require.NoError(t, c.Close(context.Background())) +} + +// TestPayloadConvertersDriveItemConversion proves that the configured +// PayloadConverters (not the default set) serialize each item. With only the +// byte-slice converter, a []byte round-trips but a string has no converter and +// fails — whereas the default set's JSON fallback would have accepted it. +func TestPayloadConvertersDriveItemConversion(t *testing.T) { + fc := &fakeClient{} + c := NewClient(fc, "wf-1", Options{ + PayloadConverters: []converter.PayloadConverter{converter.NewByteSlicePayloadConverter()}, + }) + + c.Topic("events").Publish([]byte("hi"), false) + require.NoError(t, c.Flush(context.Background())) + sigs := fc.recorded() + require.Len(t, sigs, 1) + payload, err := decodePayloadWire(sigs[0].Items[0].Data) + require.NoError(t, err) + require.Equal(t, "binary/plain", string(payload.Metadata[converter.MetadataEncoding]), + "item must be serialized by the configured byte-slice converter") + require.NoError(t, c.Close(context.Background())) + + // A string is unconvertible under the byte-slice-only set, so the flush + // fails — the default set's JSON fallback would have accepted it. Use a + // fresh client since the unconvertible item stays buffered after the error. + c2 := NewClient(&fakeClient{}, "wf-1", Options{ + PayloadConverters: []converter.PayloadConverter{converter.NewByteSlicePayloadConverter()}, + }) + c2.Topic("events").Publish("not-bytes", false) + require.Error(t, c2.Flush(context.Background())) +} + +func TestFlushNoopWhenEmpty(t *testing.T) { + fc := &fakeClient{} + c := NewClient(fc, "wf-1", Options{}) + require.NoError(t, c.Flush(context.Background())) + require.Empty(t, fc.recorded()) +} + +func TestSequenceAdvancesAcrossFlushes(t *testing.T) { + fc := &fakeClient{} + c := NewClient(fc, "wf-1", Options{}) + + c.Topic("t").Publish("x", false) + require.NoError(t, c.Flush(context.Background())) + c.Topic("t").Publish("y", false) + require.NoError(t, c.Flush(context.Background())) + + sigs := fc.recorded() + require.Len(t, sigs, 2) + require.EqualValues(t, 1, sigs[0].Sequence) + require.EqualValues(t, 2, sigs[1].Sequence) + require.Equal(t, sigs[0].PublisherID, sigs[1].PublisherID) + + require.NoError(t, c.Close(context.Background())) +} + +func TestMaxBatchSizeTriggersFlush(t *testing.T) { + fc := &fakeClient{} + // Long interval so only the size threshold can trigger a flush. + c := NewClient(fc, "wf-1", Options{BatchInterval: time.Hour, MaxBatchSize: 2}) + + c.Topic("t").Publish("a", false) + c.Topic("t").Publish("b", false) // reaches MaxBatchSize → flush + + require.Eventually(t, func() bool { + return len(fc.recorded()) == 1 + }, time.Second, 5*time.Millisecond) + + require.NoError(t, c.Close(context.Background())) +} + +func TestCloseDrainsBuffer(t *testing.T) { + fc := &fakeClient{} + c := NewClient(fc, "wf-1", Options{BatchInterval: time.Hour}) + + c.Topic("t").Publish("a", false) + require.NoError(t, c.Close(context.Background())) + + sigs := fc.recorded() + require.Len(t, sigs, 1) + require.Len(t, sigs[0].Items, 1) +} + +func TestForceFlush(t *testing.T) { + fc := &fakeClient{} + c := NewClient(fc, "wf-1", Options{BatchInterval: time.Hour}) + + c.Topic("t").Publish("a", true) // forceFlush + + require.Eventually(t, func() bool { + return len(fc.recorded()) == 1 + }, time.Second, 5*time.Millisecond) + + require.NoError(t, c.Close(context.Background())) +} + +func TestFlushTimeoutAfterMaxRetryDuration(t *testing.T) { + const retryWindow = time.Millisecond + fc := &fakeClient{signalErr: errors.New("boom")} + c := NewClient(fc, "wf-1", Options{BatchInterval: time.Hour, MaxRetryDuration: retryWindow}) + + c.Topic("t").Publish("a", false) + + // The first flush sets pending and fails to send (transient "boom"). + require.Error(t, c.Flush(context.Background())) + + // Wait past the retry window with ample margin for coarse OS timer + // granularity (notably on Windows, where sub-tick durations can read as + // zero). The next flush sees the window exceeded and returns + // FlushTimeoutError. + time.Sleep(50 * time.Millisecond) + + var fte *FlushTimeoutError + require.ErrorAs(t, c.Flush(context.Background()), &fte) + + require.NoError(t, c.Close(context.Background())) +} diff --git a/contrib/workflowstreams/codec.go b/contrib/workflowstreams/codec.go new file mode 100644 index 000000000..f27e2e1be --- /dev/null +++ b/contrib/workflowstreams/codec.go @@ -0,0 +1,38 @@ +package workflowstreams + +import ( + "encoding/base64" + "fmt" + + commonpb "go.temporal.io/api/common/v1" + "google.golang.org/protobuf/proto" +) + +// encodePayloadWire encodes a Payload to the base64-of-proto wire format shared +// across the Go, Python, and TypeScript packages. +func encodePayloadWire(payload *commonpb.Payload) (string, error) { + b, err := proto.Marshal(payload) + if err != nil { + return "", fmt.Errorf("workflowstreams: marshal payload: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil +} + +// decodePayloadWire decodes the base64-of-proto wire format back to a Payload. +func decodePayloadWire(wire string) (*commonpb.Payload, error) { + b, err := base64.StdEncoding.DecodeString(wire) + if err != nil { + return nil, fmt.Errorf("workflowstreams: decode base64 payload: %w", err) + } + payload := &commonpb.Payload{} + if err := proto.Unmarshal(b, payload); err != nil { + return nil, fmt.Errorf("workflowstreams: unmarshal payload: %w", err) + } + return payload, nil +} + +// payloadWireSize estimates the contribution of a single encoded item to a poll +// response. encoded is already base64 (its on-wire representation). +func payloadWireSize(encoded, topic string) int { + return len(encoded) + len(topic) +} diff --git a/contrib/workflowstreams/codec_test.go b/contrib/workflowstreams/codec_test.go new file mode 100644 index 000000000..776be2c80 --- /dev/null +++ b/contrib/workflowstreams/codec_test.go @@ -0,0 +1,66 @@ +package workflowstreams + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" + "go.temporal.io/sdk/converter" + "google.golang.org/protobuf/proto" +) + +func TestPayloadWireRoundTrip(t *testing.T) { + payload, err := converter.GetDefaultDataConverter().ToPayload("hello") + require.NoError(t, err) + + wire, err := encodePayloadWire(payload) + require.NoError(t, err) + + got, err := decodePayloadWire(wire) + require.NoError(t, err) + require.True(t, proto.Equal(payload, got)) + + // The decoded payload still carries its encoding metadata so a consumer can + // decode it back to the original value. + var s string + require.NoError(t, converter.GetDefaultDataConverter().FromPayload(got, &s)) + require.Equal(t, "hello", s) +} + +func TestPayloadWireFormatIsBase64OfProto(t *testing.T) { + payload := &commonpb.Payload{ + Metadata: map[string][]byte{"encoding": []byte("json/plain")}, + Data: []byte(`"hi"`), + } + wire, err := encodePayloadWire(payload) + require.NoError(t, err) + + // Wire format is base64-of-marshaled-proto; decoding base64 then proto must + // reproduce the payload. This is the contract shared with the Python and + // TypeScript packages. + raw, err := base64.StdEncoding.DecodeString(wire) + require.NoError(t, err) + var decoded commonpb.Payload + require.NoError(t, proto.Unmarshal(raw, &decoded)) + require.True(t, proto.Equal(payload, &decoded)) +} + +func TestDecodePayloadWireRejectsBadBase64(t *testing.T) { + _, err := decodePayloadWire("not valid base64!!!") + require.Error(t, err) +} + +func TestBinaryPayloadRoundTrip(t *testing.T) { + payload, err := converter.GetDefaultDataConverter().ToPayload([]byte{0x00, 0x01, 0xff}) + require.NoError(t, err) + + wire, err := encodePayloadWire(payload) + require.NoError(t, err) + got, err := decodePayloadWire(wire) + require.NoError(t, err) + + var b []byte + require.NoError(t, converter.GetDefaultDataConverter().FromPayload(got, &b)) + require.Equal(t, []byte{0x00, 0x01, 0xff}, b) +} diff --git a/contrib/workflowstreams/doc.go b/contrib/workflowstreams/doc.go new file mode 100644 index 000000000..6db9b3099 --- /dev/null +++ b/contrib/workflowstreams/doc.go @@ -0,0 +1,77 @@ +// Package workflowstreams provides a durable publish/subscribe log hosted +// inside a Temporal workflow. +// +// External code (activities, starters, other workflows) publishes messages to +// named topics via signals; subscribers long-poll for new items via updates; +// a query exposes the current offset. The stream is backed by Temporal's +// durable execution, giving exactly-once, ordered, cross-language delivery with +// client-side batching, publisher dedup, continue-as-new survival, truncation, +// and response paging. +// +// # Workflow side +// +// Construct a [WorkflowStream] once at the start of your workflow function. The +// constructor registers the publish signal, poll update, and offset query +// handlers on the current workflow: +// +// type MyInput struct { +// ItemsProcessed int // your own workflow state +// StreamState *workflowstreams.WorkflowStreamState +// } +// +// func MyWorkflow(ctx workflow.Context, input MyInput) error { +// stream, err := workflowstreams.NewWorkflowStream(ctx, input.StreamState) +// if err != nil { +// return err +// } +// // Optionally publish from workflow code: +// _ = stream.Topic("events").Publish("hello") +// // ... run your workflow; the stream serves external publishers/subscribers. +// // Block until your workflow's exit condition is met (here, a done flag +// // set elsewhere, e.g. by a signal). +// return workflow.Await(ctx, func() bool { return done }) +// } +// +// Continue-as-new starts a fresh run with an empty history, so the stream's log +// and offsets must be carried across each boundary. This is a round-trip: when +// rolling over, return [WorkflowStream.NewContinueAsNewError] instead of a plain +// workflow.NewContinueAsNewError. It snapshots the stream state and hands it to +// your callback, which builds the next run's arguments — carry your own state +// forward alongside the captured state: +// +// return stream.NewContinueAsNewError(ctx, MyWorkflow, func(state *workflowstreams.WorkflowStreamState) []any { +// return []any{MyInput{ItemsProcessed: itemsProcessed, StreamState: state}} +// }) +// +// On the next run that captured state arrives as MyInput.StreamState, the value +// passed to [NewWorkflowStream] above: nil on a fresh start, non-nil after a +// roll-over, so the constructor rehydrates the log automatically. +// +// # Client side +// +// From an activity, starter, or any code with a [client.Client], use a +// [Client] to publish and subscribe: +// +// c := workflowstreams.NewClient(temporalClient, workflowID, workflowstreams.Options{}) +// defer c.Close(ctx) +// c.Topic("events").Publish("from outside", false) +// for item, err := range c.Subscribe(ctx, workflowstreams.SubscribeOptions{Topics: []string{"events"}}) { +// if err != nil { +// return err +// } +// var s string +// _ = converter.GetDefaultDataConverter().FromPayload(item.Data, &s) +// } +// +// # Cross-language protocol +// +// The handler names (see PublishSignalName, PollUpdateName, OffsetQueryName), +// the JSON envelope field names, and the per-item payload encoding (base64 of +// the marshaled temporal.api.common.v1.Payload) match workflow streams +// packages in other languages for interoperability. +// +// The data converter codec chain (encryption, compression) runs once on the +// signal/update envelope that carries each batch — not per item — so payloads +// are never double-encoded. Each per-item Payload still carries its encoding +// metadata so consumers can decode it with a payload converter. +package workflowstreams diff --git a/contrib/workflowstreams/go.mod b/contrib/workflowstreams/go.mod new file mode 100644 index 000000000..debcee4c2 --- /dev/null +++ b/contrib/workflowstreams/go.mod @@ -0,0 +1,35 @@ +module go.temporal.io/sdk/contrib/workflowstreams + +go 1.24.0 + +require ( + github.com/google/uuid v1.6.0 + github.com/stretchr/testify v1.11.1 + go.temporal.io/api v1.62.12 + go.temporal.io/sdk v1.44.1 + google.golang.org/protobuf v1.36.11 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/mock v1.6.0 // indirect + github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect + github.com/nexus-rpc/sdk-go v0.6.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/robfig/cron v1.2.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect + golang.org/x/net v0.49.0 // indirect + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/text v0.33.0 // indirect + golang.org/x/time v0.3.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260120221211-b8f7ae30c516 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260120221211-b8f7ae30c516 // indirect + google.golang.org/grpc v1.79.3 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace go.temporal.io/sdk => ../../ diff --git a/contrib/workflowstreams/go.sum b/contrib/workflowstreams/go.sum new file mode 100644 index 000000000..831c156f2 --- /dev/null +++ b/contrib/workflowstreams/go.sum @@ -0,0 +1,117 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a h1:yDWHCSQ40h88yih2JAcL6Ls/kVkSE8GFACTGVnMPruw= +github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a/go.mod h1:7Ga40egUymuWXxAe151lTNnCv97MddSOVsjpPPkityA= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 h1:sGm2vDRFUrQJO/Veii4h4zG2vvqG6uWNkBHSTqXOZk0= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2/go.mod h1:wd1YpapPLivG6nQgbf7ZkG1hhSOXDhhn4MLTknx2aAc= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/nexus-rpc/sdk-go v0.6.0 h1:QRgnP2zTbxEbiyWG/aXH8uSC5LV/Mg1fqb19jb4DBlo= +github.com/nexus-rpc/sdk-go v0.6.0/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ= +github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.temporal.io/api v1.62.12 h1:627rVnItegQmrszg1bH4vfyc/1uNo5qCereCNkvZefw= +go.temporal.io/api v1.62.12/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/api v0.0.0-20260120221211-b8f7ae30c516 h1:vmC/ws+pLzWjj/gzApyoZuSVrDtF1aod4u/+bbj8hgM= +google.golang.org/genproto/googleapis/api v0.0.0-20260120221211-b8f7ae30c516/go.mod h1:p3MLuOwURrGBRoEyFHBT3GjUwaCQVKeNqqWxlcISGdw= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260120221211-b8f7ae30c516 h1:sNrWoksmOyF5bvJUcnmbeAmQi8baNhqg5IWaI3llQqU= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260120221211-b8f7ae30c516/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= +google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/contrib/workflowstreams/publisher.go b/contrib/workflowstreams/publisher.go new file mode 100644 index 000000000..90348c9ea --- /dev/null +++ b/contrib/workflowstreams/publisher.go @@ -0,0 +1,302 @@ +package workflowstreams + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/google/uuid" + commonpb "go.temporal.io/api/common/v1" + "go.temporal.io/sdk/converter" +) + +// FlushTimeoutError is returned when a flush retry exceeds MaxRetryDuration. +// The pending batch is dropped; if the signal had already been delivered the +// items are in the log, otherwise they are lost. +type FlushTimeoutError struct { + msg string +} + +func (e *FlushTimeoutError) Error() string { return e.msg } + +type bufItem struct { + topic string + value any +} + +// publisher owns the client-side publish path: it buffers published values, +// batches them, and sends each batch to the workflow via the injected signal +// function. It assigns the per-publisher dedup key (a stable publisher ID plus +// a monotonic sequence advanced only on a confirmed send) so the workflow can +// drop duplicates, and it retries a failed batch until maxRetryDur elapses. +// +// The signal func is injected (rather than holding a client.Client) so the +// publish path can be exercised in isolation. +type publisher struct { + signal func(ctx context.Context, in PublishInput) error + dc converter.DataConverter + publisherID string + batchInterval time.Duration + maxBatchSize int + maxRetryDur time.Duration + + mu sync.Mutex + buffer []bufItem + pending []PublishEntry + pendingSeq int64 + sequence int64 + pendingStart time.Time + started bool + closed bool + err error + + flushMu sync.Mutex // serializes doFlush + trigger chan struct{} + stop chan struct{} + done chan struct{} +} + +func newPublisher(signal func(context.Context, PublishInput) error, dc converter.DataConverter, opts Options) *publisher { + p := &publisher{ + signal: signal, + dc: dc, + publisherID: strings.ReplaceAll(uuid.NewString(), "-", "")[:16], + batchInterval: opts.BatchInterval, + maxBatchSize: opts.MaxBatchSize, + maxRetryDur: opts.MaxRetryDuration, + trigger: make(chan struct{}, 1), + stop: make(chan struct{}), + done: make(chan struct{}), + } + if p.batchInterval <= 0 { + p.batchInterval = defaultBatchInterval + } + if p.maxRetryDur <= 0 { + p.maxRetryDur = defaultMaxRetryDuration + } + return p +} + +// publish buffers a value and lazily starts the background flush loop. It +// triggers an immediate flush on forceFlush or once the buffer reaches +// maxBatchSize. +func (p *publisher) publish(topic string, value any, forceFlush bool) { + p.mu.Lock() + p.buffer = append(p.buffer, bufItem{topic: topic, value: value}) + trigger := forceFlush || (p.maxBatchSize > 0 && len(p.buffer) >= p.maxBatchSize) + closed := p.closed + p.mu.Unlock() + + if !closed { + p.ensureStarted() + } + if trigger { + p.triggerFlush() + } +} + +func (p *publisher) ensureStarted() { + p.mu.Lock() + defer p.mu.Unlock() + if p.started || p.closed { + return + } + p.started = true + go p.run() +} + +func (p *publisher) triggerFlush() { + select { + case p.trigger <- struct{}{}: + default: + } +} + +func (p *publisher) run() { + defer close(p.done) + ticker := time.NewTicker(p.batchInterval) + defer ticker.Stop() + for { + select { + case <-p.stop: + return + case <-p.trigger: + case <-ticker.C: + } + if err := p.doFlush(context.Background()); err != nil { + var fte *FlushTimeoutError + if errors.As(err, &fte) { + // The pending batch was dropped and can't be recovered. Stash + // the error so flush/close surface it and stop the loop. + p.mu.Lock() + p.err = err + p.mu.Unlock() + return + } + // Transient failure: pending stays set for retry on the next tick. + } + } +} + +// doFlush sends the pending batch (retry) or encodes and sends the buffer (new +// batch). It is serialized so concurrent callers send sequentially. +func (p *publisher) doFlush(ctx context.Context) error { + p.flushMu.Lock() + defer p.flushMu.Unlock() + + var batch []PublishEntry + var seq int64 + + p.mu.Lock() + switch { + case p.pending != nil: + if !p.pendingStart.IsZero() && time.Since(p.pendingStart) > p.maxRetryDur { + // Advance the confirmed sequence so the next batch gets a fresh + // sequence number. Without this the next batch reuses pendingSeq, + // which the workflow may have already accepted — causing silent + // dedup (data loss). + p.sequence = p.pendingSeq + p.pending = nil + p.pendingSeq = 0 + p.pendingStart = time.Time{} + p.mu.Unlock() + return &FlushTimeoutError{msg: fmt.Sprintf( + "workflowstreams: flush retry exceeded MaxRetryDuration (%s); pending batch dropped", p.maxRetryDur)} + } + batch = p.pending + seq = p.pendingSeq + case len(p.buffer) > 0: + encoded, err := p.encodeBuffer(p.buffer) + if err != nil { + p.mu.Unlock() + return err // buffer left intact for a later flush + } + p.buffer = nil + batch = encoded + seq = p.sequence + 1 + p.pending = batch + p.pendingSeq = seq + p.pendingStart = time.Now() + default: + p.mu.Unlock() + return nil + } + p.mu.Unlock() + + // On failure the signal returns an error and pending stays set for retry. + if err := p.signal(ctx, PublishInput{Items: batch, PublisherID: p.publisherID, Sequence: seq}); err != nil { + return err + } + + p.mu.Lock() + p.sequence = seq + p.pending = nil + p.pendingSeq = 0 + p.pendingStart = time.Time{} + p.mu.Unlock() + return nil +} + +func (p *publisher) encodeBuffer(items []bufItem) ([]PublishEntry, error) { + out := make([]PublishEntry, len(items)) + for i, it := range items { + payload, ok := it.value.(*commonpb.Payload) + if !ok { + var err error + payload, err = p.dc.ToPayload(it.value) + if err != nil { + return nil, fmt.Errorf("workflowstreams: convert value: %w", err) + } + } + data, err := encodePayloadWire(payload) + if err != nil { + return nil, err + } + out[i] = PublishEntry{Topic: it.topic, Data: data} + } + return out, nil +} + +// flush sends buffered (and pending) items and waits for server confirmation. +// It returns once the items buffered at call time have been signaled and +// acknowledged, or a FlushTimeoutError if a pending batch cannot be sent within +// maxRetryDur. +func (p *publisher) flush(ctx context.Context) error { + if err := p.takeError(); err != nil { + return err + } + + p.mu.Lock() + if p.pending == nil && len(p.buffer) == 0 { + p.mu.Unlock() + return nil + } + baseSeq := p.sequence + if p.pending != nil { + baseSeq = p.pendingSeq + } + targetSeq := baseSeq + if len(p.buffer) > 0 { + targetSeq = baseSeq + 1 + } + p.mu.Unlock() + + for { + p.mu.Lock() + cur := p.sequence + p.mu.Unlock() + if cur >= targetSeq { + break + } + if err := p.doFlush(ctx); err != nil { + return err + } + } + return p.takeError() +} + +// close stops the background flush loop and drains any remaining items, +// surfacing a deferred FlushTimeoutError from a prior background failure. +func (p *publisher) close(ctx context.Context) error { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return nil + } + p.closed = true + started := p.started + p.mu.Unlock() + + if started { + close(p.stop) + <-p.done + } + + // Final drain: a single doFlush processes either pending OR the buffer. + for { + p.mu.Lock() + more := p.pending != nil || len(p.buffer) > 0 + p.mu.Unlock() + if !more { + break + } + if err := p.doFlush(ctx); err != nil { + return err + } + } + return p.takeError() +} + +func (p *publisher) takeError() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.err != nil { + err := p.err + p.err = nil + return err + } + return nil +} diff --git a/contrib/workflowstreams/subscribe_test.go b/contrib/workflowstreams/subscribe_test.go new file mode 100644 index 000000000..b6aae084a --- /dev/null +++ b/contrib/workflowstreams/subscribe_test.go @@ -0,0 +1,347 @@ +package workflowstreams + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" + workflowpb "go.temporal.io/api/workflow/v1" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/converter" + "go.temporal.io/sdk/temporal" +) + +// pollStep is a single scripted reply to one Subscribe poll. updateErr fails the +// UpdateWorkflow call itself; getErr fails the handle.Get; otherwise result is +// returned. runID is the run the update is admitted to, surfaced via the handle's +// RunID() so Subscribe can describe that specific run on failure. +type pollStep struct { + runID string + updateErr error + getErr error + result PollResult +} + +// errStepsExhausted is returned once a fake runs out of scripted poll steps. The +// default describe status (COMPLETED) makes the loop treat it as a clean, +// terminal end, so a test that forgets to break can't spin forever. +var errStepsExhausted = errors.New("fake: poll steps exhausted") + +// fakeSubClient scripts UpdateWorkflow/DescribeWorkflowExecution responses so the +// Subscribe polling loop can be exercised without a server. Embedding +// client.Client satisfies the rest of the interface; the unused methods panic if +// called, which the subscribe tests never do. +type fakeSubClient struct { + client.Client + + mu sync.Mutex + steps []pollStep + idx int + polls []PollInput // PollInput captured per UpdateWorkflow call, in order + + // describeByRun maps a run id to the status DescribeWorkflowExecution reports + // for it. A run not present defaults to COMPLETED, so loops terminate cleanly. + describeByRun map[string]enumspb.WorkflowExecutionStatus + describeRuns []string // run ids passed to DescribeWorkflowExecution, in order + describeErr error +} + +func (f *fakeSubClient) UpdateWorkflow(_ context.Context, options client.UpdateWorkflowOptions) (client.WorkflowUpdateHandle, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.polls = append(f.polls, options.Args[0].(PollInput)) + if f.idx >= len(f.steps) { + return nil, errStepsExhausted + } + step := f.steps[f.idx] + f.idx++ + if step.updateErr != nil { + return nil, step.updateErr + } + return &fakeUpdateHandle{runID: step.runID, result: step.result, err: step.getErr}, nil +} + +func (f *fakeSubClient) DescribeWorkflowExecution(_ context.Context, _, runID string) (*workflowservice.DescribeWorkflowExecutionResponse, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.describeRuns = append(f.describeRuns, runID) + if f.describeErr != nil { + return nil, f.describeErr + } + status, ok := f.describeByRun[runID] + if !ok { + status = enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED + } + return &workflowservice.DescribeWorkflowExecutionResponse{ + WorkflowExecutionInfo: &workflowpb.WorkflowExecutionInfo{Status: status}, + }, nil +} + +func (f *fakeSubClient) recordedPolls() []PollInput { + f.mu.Lock() + defer f.mu.Unlock() + return append([]PollInput(nil), f.polls...) +} + +func (f *fakeSubClient) recordedDescribeRuns() []string { + f.mu.Lock() + defer f.mu.Unlock() + return append([]string(nil), f.describeRuns...) +} + +// fakeUpdateHandle returns a scripted PollResult (or error) from Get and reports +// its run id via RunID(). The remaining WorkflowUpdateHandle methods are unused by +// Subscribe and panic via the nil embedded interface if called. +type fakeUpdateHandle struct { + client.WorkflowUpdateHandle + runID string + result PollResult + err error +} + +func (h *fakeUpdateHandle) RunID() string { return h.runID } + +func (h *fakeUpdateHandle) Get(_ context.Context, valuePtr interface{}) error { + if h.err != nil { + return h.err + } + if p, ok := valuePtr.(*PollResult); ok { + *p = h.result + } + return nil +} + +// newSubClient builds a Client wired to fc directly, bypassing NewClient so no +// background publisher is started. +func newSubClient(fc *fakeSubClient) *Client { + return &Client{c: fc, workflowID: "wf", followCAN: true} +} + +// wireItem encodes value into a WireItem the way a workflow would. +func wireItem(t *testing.T, topic, value string, offset int64) WireItem { + t.Helper() + payload, err := converter.GetDefaultDataConverter().ToPayload(value) + require.NoError(t, err) + enc, err := encodePayloadWire(payload) + require.NoError(t, err) + return WireItem{Topic: topic, Data: enc, Offset: offset} +} + +func decodeItem(t *testing.T, p *commonpb.Payload) string { + t.Helper() + var s string + require.NoError(t, converter.GetDefaultDataConverter().FromPayload(p, &s)) + return s +} + +func TestSubscribeDeliversItemsAndAdvancesOffset(t *testing.T) { + fc := &fakeSubClient{steps: []pollStep{ + {result: PollResult{Items: []WireItem{wireItem(t, "evt", "a", 1)}, NextOffset: 2, MoreReady: true}}, + {result: PollResult{Items: []WireItem{wireItem(t, "evt", "b", 2)}, NextOffset: 3, MoreReady: true}}, + }} + c := newSubClient(fc) + + var got []string + var gotOffsets []int64 + for item, err := range c.Subscribe(context.Background(), SubscribeOptions{FromOffset: 1}) { + require.NoError(t, err) + got = append(got, decodeItem(t, item.Data)) + gotOffsets = append(gotOffsets, item.Offset) + if len(got) == 2 { + break + } + } + + require.Equal(t, []string{"a", "b"}, got) + require.Equal(t, []int64{1, 2}, gotOffsets) + + polls := fc.recordedPolls() + require.GreaterOrEqual(t, len(polls), 2) + require.EqualValues(t, 1, polls[0].FromOffset, "first poll uses the requested offset") + require.EqualValues(t, 2, polls[1].FromOffset, "second poll advances to the prior NextOffset") +} + +func TestSubscribeTruncationResetsOffset(t *testing.T) { + truncated := temporal.NewNonRetryableApplicationError("truncated", ErrTypeTruncatedOffset, nil) + fc := &fakeSubClient{steps: []pollStep{ + {getErr: truncated}, + {result: PollResult{Items: []WireItem{wireItem(t, "evt", "a", 0)}, NextOffset: 1, MoreReady: true}}, + }} + c := newSubClient(fc) + + var got []string + for item, err := range c.Subscribe(context.Background(), SubscribeOptions{FromOffset: 5}) { + require.NoError(t, err) + got = append(got, decodeItem(t, item.Data)) + break + } + + require.Equal(t, []string{"a"}, got) + + polls := fc.recordedPolls() + require.GreaterOrEqual(t, len(polls), 2) + require.EqualValues(t, 5, polls[0].FromOffset, "first poll uses the requested offset") + require.EqualValues(t, 0, polls[1].FromOffset, "truncation restarts from the beginning") + require.Empty(t, fc.recordedDescribeRuns(), "truncation is handled without describing the workflow") +} + +func TestSubscribeTerminalEndsCleanly(t *testing.T) { + fc := &fakeSubClient{ + steps: []pollStep{{runID: "R1", getErr: errors.New("workflow gone")}}, + describeByRun: map[string]enumspb.WorkflowExecutionStatus{ + // The polled run itself ended — not a rollover. + "R1": enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED, + }, + } + c := newSubClient(fc) + + var yields int + for _, err := range c.Subscribe(context.Background(), SubscribeOptions{}) { + yields++ + require.NoError(t, err, "terminal workflow should end the stream without surfacing an error") + } + require.Zero(t, yields, "no items and no error are yielded on a clean terminal end") + require.Equal(t, []string{"R1"}, fc.recordedDescribeRuns()[:1], "the polled run is the one described") +} + +func TestSubscribeContinueAsNewRetries(t *testing.T) { + // The poll on run R1 fails because R1 continued-as-new mid-poll; describing R1 + // reports CONTINUED_AS_NEW, so Subscribe retries and the next poll lands on the + // successor run R2. + fc := &fakeSubClient{ + steps: []pollStep{ + {runID: "R1", getErr: errors.New("update lost to continue-as-new")}, + {runID: "R2", result: PollResult{Items: []WireItem{wireItem(t, "evt", "after-can", 1)}, NextOffset: 2, MoreReady: true}}, + }, + describeByRun: map[string]enumspb.WorkflowExecutionStatus{ + "R1": enumspb.WORKFLOW_EXECUTION_STATUS_CONTINUED_AS_NEW, + }, + } + c := newSubClient(fc) + + var got []string + for item, err := range c.Subscribe(context.Background(), SubscribeOptions{}) { + require.NoError(t, err) + got = append(got, decodeItem(t, item.Data)) + break + } + + require.Equal(t, []string{"after-can"}, got) + require.GreaterOrEqual(t, len(fc.recordedPolls()), 2, "the poll is retried after following continue-as-new") + require.Equal(t, []string{"R1"}, fc.recordedDescribeRuns(), "the rollover is detected by describing the polled run R1, not the latest run") +} + +func TestSubscribeRetriesWhileDraining(t *testing.T) { + // While the workflow is detaching for continue-as-new, the poll validator + // rejects with ErrTypeStreamDraining. Subscribe must back off and retry + // rather than surface the rejection; the next poll then lands on the + // successor run. + draining := temporal.NewApplicationError("workflow is draining", ErrTypeStreamDraining) + fc := &fakeSubClient{steps: []pollStep{ + {runID: "R1", getErr: draining}, + {runID: "R2", result: PollResult{Items: []WireItem{wireItem(t, "evt", "after-can", 1)}, NextOffset: 2, MoreReady: true}}, + }} + c := newSubClient(fc) + + var got []string + var gotErr error + for item, err := range c.Subscribe(context.Background(), SubscribeOptions{PollCooldown: time.Millisecond}) { + if err != nil { + gotErr = err + break + } + got = append(got, decodeItem(t, item.Data)) + break + } + + require.NoError(t, gotErr, "a draining rejection must not surface as an error") + require.Equal(t, []string{"after-can"}, got) + require.GreaterOrEqual(t, len(fc.recordedPolls()), 2, "the poll is retried after the draining rejection") + require.Empty(t, fc.recordedDescribeRuns(), "a draining rejection is retried without describing the workflow") +} + +func TestSubscribeSurfacesNonTerminalError(t *testing.T) { + // A transient failure on a run that is still RUNNING (no rollover, not + // terminal) must surface rather than retry forever. + boom := errors.New("boom") + fc := &fakeSubClient{ + steps: []pollStep{{runID: "R1", getErr: boom}}, + describeByRun: map[string]enumspb.WorkflowExecutionStatus{ + "R1": enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, + }, + } + c := newSubClient(fc) + + var gotErr error + var yields int + for _, err := range c.Subscribe(context.Background(), SubscribeOptions{}) { + yields++ + gotErr = err + } + require.Equal(t, 1, yields) + require.ErrorIs(t, gotErr, boom, "a non-terminal poll error is surfaced to the caller") +} + +func TestSubscribeContextCanceledBeforePolling(t *testing.T) { + fc := &fakeSubClient{steps: []pollStep{ + {result: PollResult{Items: []WireItem{wireItem(t, "evt", "a", 1)}, NextOffset: 2}}, + }} + c := newSubClient(fc) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var gotErr error + for _, err := range c.Subscribe(ctx, SubscribeOptions{}) { + gotErr = err + } + require.ErrorIs(t, gotErr, context.Canceled) + require.Empty(t, fc.recordedPolls(), "a canceled context short-circuits before any poll") +} + +func TestSubscribeCooldownCanceledByContext(t *testing.T) { + // First poll succeeds with no items and MoreReady=false, so the loop enters + // the cooldown wait. A long cooldown plus a canceled context proves the wait + // is interruptible rather than blocking for the full PollCooldown. + fc := &fakeSubClient{steps: []pollStep{ + {result: PollResult{Items: nil, NextOffset: 1, MoreReady: false}}, + }} + c := newSubClient(fc) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + start := time.Now() + var gotErr error + for _, err := range c.Subscribe(ctx, SubscribeOptions{PollCooldown: time.Hour}) { + gotErr = err + } + require.ErrorIs(t, gotErr, context.Canceled) + require.Less(t, time.Since(start), 5*time.Second, "cooldown wait should be interrupted by context cancellation") +} + +func TestSubscribeCooldownAppliedWhenNotMoreReady(t *testing.T) { + const cooldown = 60 * time.Millisecond + fc := &fakeSubClient{steps: []pollStep{ + {result: PollResult{Items: nil, NextOffset: 1, MoreReady: false}}, // triggers cooldown + {result: PollResult{Items: []WireItem{wireItem(t, "evt", "a", 1)}, NextOffset: 2, MoreReady: true}}, + }} + c := newSubClient(fc) + + start := time.Now() + for item, err := range c.Subscribe(context.Background(), SubscribeOptions{PollCooldown: cooldown}) { + require.NoError(t, err) + require.Equal(t, "a", decodeItem(t, item.Data)) + break + } + require.GreaterOrEqual(t, time.Since(start), cooldown/2, "a cooldown is applied between polls when MoreReady is false") +} diff --git a/contrib/workflowstreams/topic_handle.go b/contrib/workflowstreams/topic_handle.go new file mode 100644 index 000000000..89909d3b5 --- /dev/null +++ b/contrib/workflowstreams/topic_handle.go @@ -0,0 +1,28 @@ +package workflowstreams + +import ( + "context" + "iter" +) + +// TopicHandle publishes to and subscribes from a single topic. +type TopicHandle struct { + name string + client *Client +} + +// Name returns the topic name. +func (h *TopicHandle) Name() string { return h.name } + +// Publish buffers value for publishing on this topic. value goes through the +// client's data converter at flush time; a pre-built *commonpb.Payload bypasses +// conversion. Pass forceFlush to wake the publisher and send immediately. +func (h *TopicHandle) Publish(value any, forceFlush bool) { + h.client.pub.publish(h.name, value, forceFlush) +} + +// Subscribe returns an iterator over items on this topic, starting at +// fromOffset. See Client.Subscribe. +func (h *TopicHandle) Subscribe(ctx context.Context, fromOffset int64) iter.Seq2[WorkflowStreamItem, error] { + return h.client.Subscribe(ctx, SubscribeOptions{Topics: []string{h.name}, FromOffset: fromOffset}) +} diff --git a/contrib/workflowstreams/types.go b/contrib/workflowstreams/types.go new file mode 100644 index 000000000..9514193bf --- /dev/null +++ b/contrib/workflowstreams/types.go @@ -0,0 +1,107 @@ +package workflowstreams + +import ( + "time" + + commonpb "go.temporal.io/api/common/v1" +) + +// Fixed handler names. These are part of the cross-language wire protocol and +// match the Python and TypeScript packages exactly. The Go SDK normally +// reserves the "__temporal_" prefix, but explicitly permits the +// "__temporal_workflow_stream_" sub-namespace for this package. +const ( + // PublishSignalName is the signal external publishers send to append a + // batch of items to the stream. + PublishSignalName = "__temporal_workflow_stream_publish" + // PollUpdateName is the update subscribers send to long-poll for new items. + PollUpdateName = "__temporal_workflow_stream_poll" + // OffsetQueryName is the query that returns the current global offset. + OffsetQueryName = "__temporal_workflow_stream_offset" +) + +// Error types surfaced by the poll update and truncate. These match the type +// strings used by the Python and TypeScript packages. +const ( + // ErrTypeTruncatedOffset is the ApplicationError type returned by the poll + // update when the requested offset has already been truncated. + ErrTypeTruncatedOffset = "TruncatedOffset" + // ErrTypeTruncateOutOfRange is the ApplicationError type returned by + // Truncate when the requested offset is past the end of the log. + ErrTypeTruncateOutOfRange = "TruncateOutOfRange" + // ErrTypeStreamDraining is the ApplicationError type the poll update's + // validator returns while the stream is detaching for continue-as-new. It + // tells a subscriber the rollover is in progress so it retries (rather than + // surfacing an error) until the poll lands on the successor run. + ErrTypeStreamDraining = "StreamDraining" +) + +// maxPollResponseBytes caps the estimated wire size of a single poll response. +// Responses that would exceed this are truncated and signal MoreReady so the +// subscriber pages through the remainder. +const maxPollResponseBytes = 1_000_000 + +// Default option values, matching the Python and TypeScript packages. +const ( + defaultBatchInterval = 2 * time.Second + defaultPollCooldown = 100 * time.Millisecond + defaultPublisherTTL = 15 * time.Minute + defaultMaxRetryDuration = 10 * time.Minute +) + +// PublishEntry is a single entry within a publish batch on the wire. Data is a +// base64-encoded, marshaled commonpb.Payload. +type PublishEntry struct { + Topic string `json:"topic"` + Data string `json:"data"` +} + +// PublishInput is the signal payload carrying a batch of entries to publish, +// along with the dedup fields. +type PublishInput struct { + Items []PublishEntry `json:"items"` + PublisherID string `json:"publisher_id"` + Sequence int64 `json:"sequence"` +} + +// PollInput is the update payload: a request to poll for new items. +type PollInput struct { + Topics []string `json:"topics"` + FromOffset int64 `json:"from_offset"` +} + +// WireItem is the wire representation of a stream item. Data is a +// base64-encoded, marshaled commonpb.Payload. +type WireItem struct { + Topic string `json:"topic"` + Data string `json:"data"` + Offset int64 `json:"offset"` +} + +// PollResult is the update response: items matching the poll request. When +// MoreReady is true the response was truncated to stay within size limits and +// the subscriber should poll again immediately rather than applying a cooldown. +type PollResult struct { + Items []WireItem `json:"items"` + NextOffset int64 `json:"next_offset"` + MoreReady bool `json:"more_ready"` +} + +// WorkflowStreamState is a serializable snapshot of stream state for +// continue-as-new. Thread a *WorkflowStreamState field through your workflow +// input and pass it to New. +type WorkflowStreamState struct { + Log []WireItem `json:"log"` + BaseOffset int64 `json:"base_offset"` + PublisherSeqs map[string]int64 `json:"publisher_sequences"` + PublisherLastSeen map[string]float64 `json:"publisher_last_seen"` +} + +// WorkflowStreamItem is a single decoded item yielded by a subscription. Data +// is the raw Payload; decode it at the call site with a payload converter, +// e.g. converter.GetDefaultDataConverter().FromPayload(item.Data, &dst). +type WorkflowStreamItem struct { + Topic string + Data *commonpb.Payload + Offset int64 +} diff --git a/contrib/workflowstreams/workflow.go b/contrib/workflowstreams/workflow.go new file mode 100644 index 000000000..a8ccd4795 --- /dev/null +++ b/contrib/workflowstreams/workflow.go @@ -0,0 +1,345 @@ +package workflowstreams + +import ( + "fmt" + "maps" + "time" + + commonpb "go.temporal.io/api/common/v1" + "go.temporal.io/sdk/converter" + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" +) + +// internalEntry is a single decoded log entry held in workflow memory. +type internalEntry struct { + topic string + payload *commonpb.Payload +} + +// WorkflowStream is the workflow-side stream object: an append-only log served +// to external publishers (via signal), subscribers (via update), and offset +// queries (via query). +// +// Construct it once at the start of your workflow function with NewWorkflowStream. The +// constructor registers all three handlers on the current workflow. +type WorkflowStream struct { + ctx workflow.Context + dc converter.DataConverter + + log []internalEntry + baseOffset int64 + publisherSeqs map[string]int64 + publisherLastSeen map[string]float64 + draining bool + + topicHandles map[string]*WorkflowTopicHandle +} + +// WorkflowStreamOption configures a WorkflowStream at construction. +type WorkflowStreamOption func(*workflowStreamConfig) + +type workflowStreamConfig struct { + payloadConverters []converter.PayloadConverter +} + +// WithPayloadConverters customizes how values published from workflow code (via +// WorkflowTopicHandle.Publish) are serialized into per-item Payloads. They are +// combined into a CompositeDataConverter in the order given, so the last one +// should be a catch-all such as converter.NewJSONPayloadConverter. +// +// As on the client side, only payload conversion happens here — never a payload +// codec. The worker's codec chain runs once on the poll-update response that +// carries each batch to subscribers, so encoding items here too would +// double-encode them; the []PayloadConverter type makes that impossible. +// +// Note: there is no public accessor for a converter set via +// workflow.WithDataConverter, so it cannot be picked up automatically; pass the +// matching payload converters here to keep workflow-side publishes consistent +// with the rest of your workflow. Default: converter.GetDefaultDataConverter(). +func WithPayloadConverters(pcs ...converter.PayloadConverter) WorkflowStreamOption { + return func(c *workflowStreamConfig) { c.payloadConverters = pcs } +} + +// NewWorkflowStream constructs a WorkflowStream and registers its signal, update, and +// query handlers on the current workflow. Pass priorState (which may be nil) to +// restore state carried across a continue-as-new boundary. +func NewWorkflowStream(ctx workflow.Context, priorState *WorkflowStreamState, opts ...WorkflowStreamOption) (*WorkflowStream, error) { + var cfg workflowStreamConfig + for _, opt := range opts { + opt(&cfg) + } + // A composite of PayloadConverters is codec-free, so workflow-published + // items are never double-encoded against the worker's response codec. + var dc converter.DataConverter = converter.GetDefaultDataConverter() + if len(cfg.payloadConverters) > 0 { + dc = converter.NewCompositeDataConverter(cfg.payloadConverters...) + } + + s := &WorkflowStream{ + ctx: ctx, + dc: dc, + publisherSeqs: map[string]int64{}, + publisherLastSeen: map[string]float64{}, + topicHandles: map[string]*WorkflowTopicHandle{}, + } + + if priorState != nil { + s.baseOffset = priorState.BaseOffset + for _, item := range priorState.Log { + payload, err := decodePayloadWire(item.Data) + if err != nil { + return nil, fmt.Errorf("workflowstreams: restore log: %w", err) + } + s.log = append(s.log, internalEntry{topic: item.Topic, payload: payload}) + } + maps.Copy(s.publisherSeqs, priorState.PublisherSeqs) + maps.Copy(s.publisherLastSeen, priorState.PublisherLastSeen) + } + + // Signals are modeled as channels in Go; drain the publish channel from a + // dedicated workflow coroutine. + workflow.Go(ctx, func(ctx workflow.Context) { + ch := workflow.GetSignalChannel(ctx, PublishSignalName) + for { + var input PublishInput + if !ch.Receive(ctx, &input) { + return // channel closed (workflow completing) + } + s.onPublish(input) + } + }) + + if err := workflow.SetUpdateHandlerWithOptions(ctx, PollUpdateName, s.onPoll, workflow.UpdateHandlerOptions{ + Validator: func(_ PollInput) error { + if s.draining { + return temporal.NewApplicationError( + "workflow is draining for continue-as-new", ErrTypeStreamDraining) + } + return nil + }, + }); err != nil { + return nil, err + } + + if err := workflow.SetQueryHandler(ctx, OffsetQueryName, func() (int64, error) { + return s.baseOffset + int64(len(s.log)), nil + }); err != nil { + return nil, err + } + + return s, nil +} + +// Topic returns a handle for publishing to name. Repeated calls with the same +// name return the same handle. +func (s *WorkflowStream) Topic(name string) *WorkflowTopicHandle { + if h, ok := s.topicHandles[name]; ok { + return h + } + h := &WorkflowTopicHandle{name: name, stream: s} + s.topicHandles[name] = h + return h +} + +// DetachPollers unblocks all waiting poll handlers and rejects new polls. Used +// before continue-as-new. +func (s *WorkflowStream) DetachPollers() { + s.draining = true +} + +// GetState returns a serializable snapshot of stream state for continue-as-new. +// It drops per-publisher sequence tracking for publishers that have not sent a +// batch within publisherTTL. +func (s *WorkflowStream) GetState(publisherTTL time.Duration) (*WorkflowStreamState, error) { + now := float64(workflow.Now(s.ctx).Unix()) + ttlSeconds := publisherTTL.Seconds() + + seqs := map[string]int64{} + seen := map[string]float64{} + for id, seq := range s.publisherSeqs { + ts := s.publisherLastSeen[id] + if now-ts < ttlSeconds { + seqs[id] = seq + seen[id] = ts + } + } + + log := make([]WireItem, 0, len(s.log)) + for _, entry := range s.log { + data, err := encodePayloadWire(entry.payload) + if err != nil { + return nil, err + } + // Per-item offset is re-derivable from baseOffset + index on reload. + log = append(log, WireItem{Topic: entry.topic, Data: data, Offset: 0}) + } + + return &WorkflowStreamState{ + Log: log, + BaseOffset: s.baseOffset, + PublisherSeqs: seqs, + PublisherLastSeen: seen, + }, nil +} + +// NewContinueAsNewError returns a continue-as-new error for wfn that you must +// return from your workflow function to end the current run, mirroring +// workflow.NewContinueAsNewError. Unlike that constructor it also drains pollers +// and blocks until in-flight handlers finish before capturing state, so it can +// take a moment to return. buildArgs receives the post-detach stream state and +// returns the positional arguments for the new run; thread the returned +// *WorkflowStreamState into your workflow input so the stream survives. +// +// return stream.NewContinueAsNewError(ctx, MyWorkflow, func(state *workflowstreams.WorkflowStreamState) []any { +// return []any{state} +// }) +// +// State is captured with the default 15-minute publisher TTL. For a custom TTL, +// use the manual recipe: DetachPollers, Await(AllHandlersFinished), GetState, +// then workflow.NewContinueAsNewError. +func (s *WorkflowStream) NewContinueAsNewError(ctx workflow.Context, wfn any, buildArgs func(state *WorkflowStreamState) []any) error { + s.DetachPollers() + if err := workflow.Await(ctx, func() bool { return workflow.AllHandlersFinished(ctx) }); err != nil { + return err + } + state, err := s.GetState(defaultPublisherTTL) + if err != nil { + return err + } + return workflow.NewContinueAsNewError(ctx, wfn, buildArgs(state)...) +} + +// Truncate discards log entries before upToOffset and advances the base offset. +// After truncation, polls requesting an offset before the new base receive a +// TruncatedOffset error. Truncate returns a non-retryable TruncateOutOfRange +// ApplicationError if upToOffset is past the end of the log. +func (s *WorkflowStream) Truncate(upToOffset int64) error { + logIndex := upToOffset - s.baseOffset + if logIndex <= 0 { + return nil + } + if logIndex > int64(len(s.log)) { + return temporal.NewNonRetryableApplicationError( + fmt.Sprintf("cannot truncate to offset %d: only %d items exist", upToOffset, s.baseOffset+int64(len(s.log))), + ErrTypeTruncateOutOfRange, nil) + } + s.log = append([]internalEntry(nil), s.log[logIndex:]...) + s.baseOffset = upToOffset + return nil +} + +func (s *WorkflowStream) onPublish(input PublishInput) { + if input.PublisherID != "" { + if input.Sequence <= s.publisherSeqs[input.PublisherID] { + return // duplicate — skip + } + s.publisherSeqs[input.PublisherID] = input.Sequence + s.publisherLastSeen[input.PublisherID] = float64(workflow.Now(s.ctx).Unix()) + } + for _, entry := range input.Items { + payload, err := decodePayloadWire(entry.Data) + if err != nil { + // A malformed entry would be a protocol violation; skip it rather + // than corrupting the log. + continue + } + s.log = append(s.log, internalEntry{topic: entry.Topic, payload: payload}) + } +} + +func (s *WorkflowStream) onPoll(ctx workflow.Context, input PollInput) (PollResult, error) { + // Wait until items at or after the requested offset are available, the + // requested offset has been truncated away, or the stream is draining. + // baseOffset can advance via Truncate while we wait, so re-evaluate the + // requested position against the current baseOffset on every check rather + // than capturing it once up front — otherwise a truncation that passes the + // waiting offset leaves the condition permanently unsatisfiable. + truncated := false + if err := workflow.Await(ctx, func() bool { + if s.draining { + return true + } + if input.FromOffset != 0 && input.FromOffset < s.baseOffset { + // The subscriber's position was truncated, possibly while waiting. + truncated = true + return true + } + // max clamps "from the beginning" to whatever is available. + logOffset := max(input.FromOffset-s.baseOffset, 0) + return int64(len(s.log)) > logOffset + }); err != nil { + return PollResult{}, err + } + if truncated { + return PollResult{}, temporal.NewNonRetryableApplicationError( + fmt.Sprintf("requested offset %d has been truncated; current base offset is %d", input.FromOffset, s.baseOffset), + ErrTypeTruncatedOffset, nil) + } + + // max clamps "From the beginning" to whatever is available. + logOffset := max(input.FromOffset-s.baseOffset, 0) + + var topicSet map[string]struct{} + if len(input.Topics) > 0 { + topicSet = make(map[string]struct{}, len(input.Topics)) + for _, t := range input.Topics { + topicSet[t] = struct{}{} + } + } + + wireItems := make([]WireItem, 0) + size := 0 + moreReady := false + nextOffset := s.baseOffset + int64(len(s.log)) + + for i := logOffset; i < int64(len(s.log)); i++ { + entry := s.log[i] + if topicSet != nil { + if _, ok := topicSet[entry.topic]; !ok { + continue + } + } + globalOffset := s.baseOffset + i + encoded, err := encodePayloadWire(entry.payload) + if err != nil { + return PollResult{}, err + } + itemSize := payloadWireSize(encoded, entry.topic) + if size+itemSize > maxPollResponseBytes && len(wireItems) > 0 { + // Resume from this item on the next poll. + nextOffset = globalOffset + moreReady = true + break + } + size += itemSize + wireItems = append(wireItems, WireItem{Topic: entry.topic, Data: encoded, Offset: globalOffset}) + } + + return PollResult{Items: wireItems, NextOffset: nextOffset, MoreReady: moreReady}, nil +} + +// WorkflowTopicHandle publishes to a single topic from workflow code. +type WorkflowTopicHandle struct { + name string + stream *WorkflowStream +} + +// Name returns the topic name. +func (h *WorkflowTopicHandle) Name() string { return h.name } + +// Publish appends value to the stream on this topic. value is serialized by the +// stream's PayloadConverters (see WithPayloadConverters), defaulting to the +// standard set; a pre-built *commonpb.Payload bypasses conversion. +func (h *WorkflowTopicHandle) Publish(value any) error { + payload, ok := value.(*commonpb.Payload) + if !ok { + var err error + payload, err = h.stream.dc.ToPayload(value) + if err != nil { + return fmt.Errorf("workflowstreams: convert value: %w", err) + } + } + h.stream.log = append(h.stream.log, internalEntry{topic: h.name, payload: payload}) + return nil +} diff --git a/contrib/workflowstreams/workflow_test.go b/contrib/workflowstreams/workflow_test.go new file mode 100644 index 000000000..4a7f6f6cd --- /dev/null +++ b/contrib/workflowstreams/workflow_test.go @@ -0,0 +1,177 @@ +package workflowstreams + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/converter" + "go.temporal.io/sdk/testsuite" + "go.temporal.io/sdk/workflow" +) + +// streamHostWorkflow hosts a WorkflowStream and runs until it receives a +// "finish" signal. priorState may be nil. +func streamHostWorkflow(ctx workflow.Context, priorState *WorkflowStreamState) error { + _, err := NewWorkflowStream(ctx, priorState) + if err != nil { + return err + } + + finished := false + workflow.Go(ctx, func(ctx workflow.Context) { + workflow.GetSignalChannel(ctx, "finish").Receive(ctx, nil) + finished = true + }) + + return workflow.Await(ctx, func() bool { return finished }) +} + +type topicVal struct { + topic string + value any +} + +func mustPublishInput(t *testing.T, publisherID string, seq int64, entries ...topicVal) PublishInput { + t.Helper() + in := PublishInput{PublisherID: publisherID, Sequence: seq} + for _, e := range entries { + payload, err := converter.GetDefaultDataConverter().ToPayload(e.value) + require.NoError(t, err) + data, err := encodePayloadWire(payload) + require.NoError(t, err) + in.Items = append(in.Items, PublishEntry{Topic: e.topic, Data: data}) + } + return in +} + +// byteOnlyPublishWorkflow restricts the stream to the byte-slice converter and +// returns whether publishing a string failed — it should, since that set has no +// converter for strings, whereas the default set's JSON fallback would accept +// it. A []byte must still publish cleanly. +func byteOnlyPublishWorkflow(ctx workflow.Context) (bool, error) { + stream, err := NewWorkflowStream(ctx, nil, WithPayloadConverters(converter.NewByteSlicePayloadConverter())) + if err != nil { + return false, err + } + if err := stream.Topic("events").Publish([]byte("hi")); err != nil { + return false, err + } + return stream.Topic("events").Publish("not-bytes") != nil, nil +} + +func TestWorkflowPublishUsesConfiguredConverters(t *testing.T) { + var ts testsuite.WorkflowTestSuite + env := ts.NewTestWorkflowEnvironment() + + env.ExecuteWorkflow(byteOnlyPublishWorkflow) + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + + var stringPublishFailed bool + require.NoError(t, env.GetWorkflowResult(&stringPublishFailed)) + require.True(t, stringPublishFailed, + "a string is unconvertible under the byte-slice-only set, proving WithPayloadConverters drives conversion") +} + +func TestExternalPublishAndOffsetQuery(t *testing.T) { + var ts testsuite.WorkflowTestSuite + env := ts.NewTestWorkflowEnvironment() + + var offset int64 + env.RegisterDelayedCallback(func() { + env.SignalWorkflow(PublishSignalName, mustPublishInput(t, "pub1", 1, + topicVal{"events", "a"}, topicVal{"events", "b"})) + }, time.Millisecond) + env.RegisterDelayedCallback(func() { + val, err := env.QueryWorkflow(OffsetQueryName) + require.NoError(t, err) + require.NoError(t, val.Get(&offset)) + }, 2*time.Millisecond) + env.RegisterDelayedCallback(func() { + env.SignalWorkflow("finish", nil) + }, 3*time.Millisecond) + + env.ExecuteWorkflow(streamHostWorkflow, (*WorkflowStreamState)(nil)) + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + require.EqualValues(t, 2, offset) +} + +func TestPublisherDedup(t *testing.T) { + var ts testsuite.WorkflowTestSuite + env := ts.NewTestWorkflowEnvironment() + + var offset int64 + env.RegisterDelayedCallback(func() { + env.SignalWorkflow(PublishSignalName, mustPublishInput(t, "pub1", 1, topicVal{"events", "a"})) + }, time.Millisecond) + env.RegisterDelayedCallback(func() { + // Same publisher + sequence: must be dropped. + env.SignalWorkflow(PublishSignalName, mustPublishInput(t, "pub1", 1, topicVal{"events", "dup"})) + }, 2*time.Millisecond) + env.RegisterDelayedCallback(func() { + env.SignalWorkflow(PublishSignalName, mustPublishInput(t, "pub1", 2, topicVal{"events", "c"})) + }, 3*time.Millisecond) + env.RegisterDelayedCallback(func() { + val, err := env.QueryWorkflow(OffsetQueryName) + require.NoError(t, err) + require.NoError(t, val.Get(&offset)) + }, 4*time.Millisecond) + env.RegisterDelayedCallback(func() { + env.SignalWorkflow("finish", nil) + }, 5*time.Millisecond) + + env.ExecuteWorkflow(streamHostWorkflow, (*WorkflowStreamState)(nil)) + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + require.EqualValues(t, 2, offset, "duplicate batch should be dropped") +} + +func TestPollReturnsItemsWithTopicFilter(t *testing.T) { + var ts testsuite.WorkflowTestSuite + env := ts.NewTestWorkflowEnvironment() + + var result PollResult + var pollErr error + env.RegisterDelayedCallback(func() { + env.SignalWorkflow(PublishSignalName, mustPublishInput(t, "pub1", 1, + topicVal{"a", "1"}, topicVal{"b", "2"}, topicVal{"a", "3"})) + }, time.Millisecond) + env.RegisterDelayedCallback(func() { + env.UpdateWorkflow(PollUpdateName, "poll1", &testsuite.TestUpdateCallback{ + OnAccept: func() {}, + OnReject: func(err error) { pollErr = err }, + OnComplete: func(success any, err error) { + if err != nil { + pollErr = err + return + } + result = success.(PollResult) + }, + }, PollInput{Topics: []string{"a"}, FromOffset: 0}) + }, 2*time.Millisecond) + env.RegisterDelayedCallback(func() { + env.SignalWorkflow("finish", nil) + }, 3*time.Millisecond) + + env.ExecuteWorkflow(streamHostWorkflow, (*WorkflowStreamState)(nil)) + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + require.NoError(t, pollErr) + + // Only topic "a" items, with global offsets 0 and 2. + require.Len(t, result.Items, 2) + require.Equal(t, "a", result.Items[0].Topic) + require.EqualValues(t, 0, result.Items[0].Offset) + require.Equal(t, "a", result.Items[1].Topic) + require.EqualValues(t, 2, result.Items[1].Offset) + require.EqualValues(t, 3, result.NextOffset) + require.False(t, result.MoreReady) + + payload, err := decodePayloadWire(result.Items[1].Data) + require.NoError(t, err) + var v string + require.NoError(t, converter.GetDefaultDataConverter().FromPayload(payload, &v)) + require.Equal(t, "3", v) +} diff --git a/internal/internal_utils.go b/internal/internal_utils.go index ccba66ea3..9dfefdeeb 100644 --- a/internal/internal_utils.go +++ b/internal/internal_utils.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -33,6 +34,10 @@ const ( temporalPrefixError = "__temporal_ is a reserved prefix" ) +func isWorkflowStreamReservedName(name string) bool { + return strings.HasPrefix(name, "__temporal_workflow_stream_") +} + // grpcContextBuilder stores all gRPC-specific parameters that will // be stored inside of a context. type grpcContextBuilder struct { diff --git a/internal/workflow.go b/internal/workflow.go index b11754d20..df1036f66 100644 --- a/internal/workflow.go +++ b/internal/workflow.go @@ -2200,7 +2200,7 @@ func (wc *workflowEnvironmentInterceptor) GetSignalChannelWithOptions( signalName string, options SignalChannelOptions, ) ReceiveChannel { - if strings.HasPrefix(signalName, temporalPrefix) { + if strings.HasPrefix(signalName, temporalPrefix) && !isWorkflowStreamReservedName(signalName) { panic(temporalPrefixError) } eo := getWorkflowEnvOptions(ctx) @@ -2529,7 +2529,7 @@ func (wc *workflowEnvironmentInterceptor) SetQueryHandlerWithOptions( handler interface{}, options QueryHandlerOptions, ) error { - if strings.HasPrefix(queryType, "__") { + if strings.HasPrefix(queryType, "__") && !isWorkflowStreamReservedName(queryType) { return errors.New("queryType starts with '__' is reserved for internal use") } return setQueryHandler(ctx, queryType, handler, options) @@ -2566,7 +2566,7 @@ func SetUpdateHandler(ctx Context, updateName string, handler interface{}, opts } func (wc *workflowEnvironmentInterceptor) SetUpdateHandler(ctx Context, name string, handler interface{}, opts UpdateHandlerOptions) error { - if strings.HasPrefix(name, "__") { + if strings.HasPrefix(name, "__") && !isWorkflowStreamReservedName(name) { return errors.New("update names starting with '__' are reserved for internal use") } return setUpdateHandler(ctx, name, handler, opts) diff --git a/internal/workflow_stream_reserved_name_test.go b/internal/workflow_stream_reserved_name_test.go new file mode 100644 index 000000000..dd674c4bd --- /dev/null +++ b/internal/workflow_stream_reserved_name_test.go @@ -0,0 +1,86 @@ +package internal + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsWorkflowStreamReservedName(t *testing.T) { + require.True(t, isWorkflowStreamReservedName("__temporal_workflow_stream_publish")) + require.True(t, isWorkflowStreamReservedName("__temporal_workflow_stream_poll")) + require.True(t, isWorkflowStreamReservedName("__temporal_workflow_stream_offset")) + require.False(t, isWorkflowStreamReservedName("__temporal_")) + require.False(t, isWorkflowStreamReservedName("__temporal_foo")) + require.False(t, isWorkflowStreamReservedName("__internal")) + require.False(t, isWorkflowStreamReservedName("events")) +} + +// TestIsWorkflowStreamReservedNameAllowsHandlers verifies that the +// "__temporal_workflow_stream_" sub-namespace is permitted for signal, update, +// and query handler registration even though the "__temporal_"/"__" prefixes +// are otherwise reserved. This backs the workflowstreams contrib package. +func TestIsWorkflowStreamReservedNameAllowsHandlers(t *testing.T) { + var ts WorkflowTestSuite + env := ts.NewTestWorkflowEnvironment() + + wf := func(ctx Context) error { + // Signal channel registration must not panic on the workflow-stream name. + _ = GetSignalChannel(ctx, "__temporal_workflow_stream_publish") + if err := SetUpdateHandler(ctx, "__temporal_workflow_stream_poll", + func(ctx Context) error { return nil }, UpdateHandlerOptions{}); err != nil { + return err + } + return SetQueryHandler(ctx, "__temporal_workflow_stream_offset", + func() (int, error) { return 0, nil }) + } + + env.ExecuteWorkflow(wf) + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) +} + +// TestIsWorkflowStreamReservedNameStillRejectsOtherNames verifies that names outside the +// workflow-stream sub-namespace are still rejected. +func TestIsWorkflowStreamReservedNameStillRejectsOtherNames(t *testing.T) { + var ts WorkflowTestSuite + + t.Run("signal panics", func(t *testing.T) { + env := ts.NewTestWorkflowEnvironment() + wf := func(ctx Context) error { + _ = GetSignalChannel(ctx, "__temporal_other") + return nil + } + env.ExecuteWorkflow(wf) + require.True(t, env.IsWorkflowCompleted()) + require.Error(t, env.GetWorkflowError()) + }) + + t.Run("update rejected", func(t *testing.T) { + env := ts.NewTestWorkflowEnvironment() + var updateErr error + wf := func(ctx Context) error { + updateErr = SetUpdateHandler(ctx, "__temporal_other", + func(ctx Context) error { return nil }, UpdateHandlerOptions{}) + return nil + } + env.ExecuteWorkflow(wf) + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + require.Error(t, updateErr) + }) + + t.Run("query rejected", func(t *testing.T) { + env := ts.NewTestWorkflowEnvironment() + var queryErr error + wf := func(ctx Context) error { + queryErr = SetQueryHandler(ctx, "__temporal_other", + func() (int, error) { return 0, nil }) + return nil + } + env.ExecuteWorkflow(wf) + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + require.Error(t, queryErr) + }) +}