From 468f407edf69afde22290a4e649f4df9b1ed3b91 Mon Sep 17 00:00:00 2001 From: bao <1727283040@qq.com> Date: Sun, 28 Jun 2026 12:33:58 +0800 Subject: [PATCH] feat: implement retry logic for before-request errors Adds functionality to retry requests when encountering errors during the before-request phase. Introduces methods to determine if a retry should occur based on custom conditions and hooks. Enhances the existing retry mechanism to handle temporary errors effectively, ensuring robust request handling. --- request.go | 93 +++++++++++++++++++++++++++++++++------------------ retry_test.go | 60 +++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 33 deletions(-) diff --git a/request.go b/request.go index d7695a32..b5ac8ec3 100644 --- a/request.go +++ b/request.go @@ -663,6 +663,60 @@ func (r *Request) Do(ctx ...context.Context) *Response { return resp } +func (r *Request) shouldRetry(resp *Response, err error) bool { + if errors.Is(err, context.Canceled) || r.retryOption == nil || + (r.RetryAttempt >= r.retryOption.MaxRetries && r.retryOption.MaxRetries >= 0) { + return false + } + needRetry := err != nil + if l := len(r.retryOption.RetryConditions); l > 0 { + for i := l - 1; i >= 0; i-- { + needRetry = r.retryOption.RetryConditions[i](resp, err) + if needRetry { + break + } + } + } + return needRetry +} + +func (r *Request) prepareRetry(resp *Response, err error) { + r.RetryAttempt++ + if l := len(r.retryOption.RetryHooks); l > 0 { + for i := l - 1; i >= 0; i-- { + r.retryOption.RetryHooks[i](resp, err) + } + } + time.Sleep(r.retryOption.GetRetryInterval(resp, r.RetryAttempt)) + + if r.dumpBuffer != nil { + r.dumpBuffer.Reset() + } + if r.trace != nil { + r.trace = &clientTrace{} + } + if resp != nil { + resp.body = nil + resp.result = nil + resp.error = nil + } +} + +// tryRetry attempts retry after an error. Returns true if the request loop should continue. +func (r *Request) tryRetry(resp **Response, err error) bool { + if *resp == nil { + *resp = &Response{Request: r} + } + if err != nil { + (*resp).Err = err + } + if !r.shouldRetry(*resp, err) { + return false + } + r.prepareRetry(*resp, err) + return true +} + func (r *Request) do() (resp *Response, err error) { defer func() { if resp == nil { @@ -673,12 +727,16 @@ func (r *Request) do() (resp *Response, err error) { } }() +retry: for { if r.Headers == nil { r.Headers = make(http.Header) } for _, f := range r.client.udBeforeRequest { if err = f(r.client, r); err != nil { + if r.tryRetry(&resp, err) { + continue retry + } return } } @@ -704,43 +762,12 @@ func (r *Request) do() (resp *Response, err error) { } } - if contextCanceled || r.retryOption == nil || (r.RetryAttempt >= r.retryOption.MaxRetries && r.retryOption.MaxRetries >= 0) { // absolutely cannot retry. + if contextCanceled { return } - - // check retry whether is needed. - needRetry := err != nil // default behaviour: retry if error occurs - if l := len(r.retryOption.RetryConditions); l > 0 { // override default behaviour if custom RetryConditions has been set. - for i := l - 1; i >= 0; i-- { - needRetry = r.retryOption.RetryConditions[i](resp, err) - if needRetry { - break - } - } - } - if !needRetry { // no retry is needed. + if !r.tryRetry(&resp, err) { return } - - // need retry, attempt to retry - r.RetryAttempt++ - if l := len(r.retryOption.RetryHooks); l > 0 { - for i := l - 1; i >= 0; i-- { // run retry hooks in reverse order - r.retryOption.RetryHooks[i](resp, err) - } - } - time.Sleep(r.retryOption.GetRetryInterval(resp, r.RetryAttempt)) - - // clean up before retry - if r.dumpBuffer != nil { - r.dumpBuffer.Reset() - } - if r.trace != nil { - r.trace = &clientTrace{} - } - resp.body = nil - resp.result = nil - resp.error = nil } } diff --git a/retry_test.go b/retry_test.go index 5814b5fe..1b843b29 100644 --- a/retry_test.go +++ b/retry_test.go @@ -2,6 +2,7 @@ package req import ( "bytes" + "errors" "io" "math" "net/http" @@ -161,6 +162,65 @@ func TestRetryWithModify(t *testing.T) { tests.AssertEqual(t, 2, resp.Request.RetryAttempt) } +func TestRetryOnBeforeRequestError(t *testing.T) { + failCount := 0 + retryHookCount := 0 + c := tc().OnBeforeRequest(func(client *Client, request *Request) error { + failCount++ + if failCount < 3 { + return errors.New("temporary before-request error") + } + return nil + }) + resp, err := c.R(). + SetRetryCount(2). + SetRetryFixedInterval(1 * time.Millisecond). + SetRetryCondition(func(resp *Response, err error) bool { + return err != nil && err.Error() == "temporary before-request error" + }). + SetRetryHook(func(resp *Response, err error) { + retryHookCount++ + }). + Get("/") + assertSuccess(t, resp, err) + tests.AssertEqual(t, 2, resp.Request.RetryAttempt) + tests.AssertEqual(t, 2, retryHookCount) + tests.AssertEqual(t, 3, failCount) +} + +func TestRetryConditionHasRequestOnBeforeRequestError(t *testing.T) { + c := tc().OnBeforeRequest(func(client *Client, request *Request) error { + return errors.New("before-request error") + }) + resp, err := c.R(). + SetRetryCount(0). + SetRetryCondition(func(resp *Response, err error) bool { + tests.AssertNotNil(t, resp) + tests.AssertNotNil(t, resp.Request) + tests.AssertEqual(t, "/header", resp.Request.RawURL) + tests.AssertEqual(t, http.MethodGet, resp.Request.Method) + tests.AssertEqual(t, err, resp.Err) + return false + }). + Get("/header") + tests.AssertNotNil(t, err) + tests.AssertNotNil(t, resp.Request) +} + +func TestNoRetryOnBeforeRequestErrorWhenConditionFalse(t *testing.T) { + c := tc().OnBeforeRequest(func(client *Client, request *Request) error { + return errors.New("not retryable") + }) + resp, err := c.R(). + SetRetryCount(3). + SetRetryCondition(func(resp *Response, err error) bool { + return false + }). + Get("/") + tests.AssertNotNil(t, err) + tests.AssertEqual(t, 0, resp.Request.RetryAttempt) +} + func TestRetryFalse(t *testing.T) { resp, err := tc().SetTimeout(2 * time.Second).R(). SetRetryCount(1).