@@ -2,20 +2,20 @@ package apiserversdk
22
33import (
44 "bytes"
5+ "context"
56 "fmt"
67 "io"
78 "net/http"
89 "net/http/httputil"
910 "net/url"
1011 "strings"
11- "time"
1212
1313 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1414 "k8s.io/apimachinery/pkg/util/net"
1515 "k8s.io/client-go/kubernetes"
1616 "k8s.io/client-go/rest"
1717
18- apiserverutil "github.com/ray-project/kuberay/apiserversdk/util"
18+ apiserversdkutil "github.com/ray-project/kuberay/apiserversdk/util"
1919 rayutil "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
2020)
2121
@@ -95,30 +95,33 @@ func requireKubeRayService(handler http.Handler, k8sClient *kubernetes.Clientset
9595// retryRoundTripper is a custom implementation of http.RoundTripper that retries HTTP requests.
9696// It verifies retryable HTTP status codes and retries using exponential backoff.
9797type retryRoundTripper struct {
98- base http.RoundTripper
99-
100- // Num of retries after the initial attempt
101- maxRetries int
102-
103- // Retry backoff settings
104- initBackoff time.Duration
105- backoffBase float64
106- maxBackoff time.Duration
98+ base http.RoundTripper
99+ retryCfg apiserversdkutil.RetryConfig
107100}
108101
109102func newRetryRoundTripper (base http.RoundTripper ) http.RoundTripper {
103+ retryCfg := apiserversdkutil.RetryConfig {
104+ MaxRetry : apiserversdkutil .HTTPClientDefaultMaxRetry ,
105+ BackoffFactor : apiserversdkutil .HTTPClientDefaultBackoffFactor ,
106+ InitBackoff : apiserversdkutil .HTTPClientDefaultInitBackoff ,
107+ MaxBackoff : apiserversdkutil .HTTPClientDefaultMaxBackoff ,
108+ OverallTimeout : apiserversdkutil .HTTPClientDefaultOverallTimeout ,
109+ }
110+
110111 return & retryRoundTripper {
111- base : base ,
112- maxRetries : apiserverutil .HTTPClientDefaultMaxRetry ,
113- initBackoff : apiserverutil .HTTPClientDefaultInitBackoff ,
114- backoffBase : apiserverutil .HTTPClientDefaultBackoffBase ,
115- maxBackoff : apiserverutil .HTTPClientDefaultMaxBackoff ,
112+ base : base ,
113+ retryCfg : retryCfg ,
116114 }
117115}
118116
119117func (rrt * retryRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
120118 ctx := req .Context ()
121119
120+ ctx , cancel := context .WithTimeout (ctx , rrt .retryCfg .OverallTimeout )
121+ defer cancel ()
122+
123+ req = req .WithContext (ctx )
124+
122125 var bodyBytes []byte
123126 var resp * http.Response
124127 var err error
@@ -135,8 +138,8 @@ func (rrt *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
135138 }
136139 }
137140
138- for attempt := 0 ; attempt <= rrt .maxRetries ; attempt ++ {
139- /* Try up to (rrt.maxRetries + 1) times: initial attempt + retries */
141+ for attempt := 0 ; attempt <= rrt .retryCfg . MaxRetry ; attempt ++ {
142+ /* Try up to (rrt.retryCfg.MaxRetry + 1) times: initial attempt + retries */
140143
141144 if bodyBytes != nil {
142145 req .Body = io .NopCloser (bytes .NewReader (bodyBytes ))
@@ -147,15 +150,15 @@ func (rrt *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
147150 return resp , fmt .Errorf ("request to %s %s failed with error: %w" , req .Method , req .URL .String (), err )
148151 }
149152
150- if apiserverutil .IsSuccessfulStatusCode (resp .StatusCode ) {
153+ if apiserversdkutil .IsSuccessfulStatusCode (resp .StatusCode ) {
151154 return resp , nil
152155 }
153156
154- if ! apiserverutil .IsRetryableHTTPStatusCodes (resp .StatusCode ) {
157+ if ! apiserversdkutil .IsRetryableHTTPStatusCodes (resp .StatusCode ) {
155158 return resp , nil
156159 }
157160
158- if attempt == rrt .maxRetries {
161+ if attempt == rrt .retryCfg . MaxRetry {
159162 return resp , nil
160163 }
161164
@@ -169,20 +172,14 @@ func (rrt *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
169172 }
170173 }
171174
172- sleepDuration := apiserverutil .GetRetryBackoff (attempt , rrt .initBackoff , rrt .backoffBase , rrt .maxBackoff )
175+ sleepDuration := apiserversdkutil .GetRetryBackoff (attempt , rrt .retryCfg . InitBackoff , rrt .retryCfg . BackoffFactor , rrt .retryCfg . MaxBackoff )
173176
174- // TODO: merge common utils for apiserver v1 and v2
175- if deadline , ok := ctx .Deadline (); ok {
176- remaining := time .Until (deadline )
177- if sleepDuration > remaining {
178- return resp , fmt .Errorf ("retry timeout exceeded context deadline" )
179- }
177+ if ok := apiserversdkutil .CheckContextDeadline (ctx , sleepDuration ); ! ok {
178+ return resp , fmt .Errorf ("retry timeout exceeded context deadline" )
180179 }
181180
182- select {
183- case <- time .After (sleepDuration ):
184- case <- ctx .Done ():
185- return resp , fmt .Errorf ("retry canceled during backoff: %w" , ctx .Err ())
181+ if err = apiserversdkutil .Sleep (ctx , sleepDuration ); err != nil {
182+ return resp , fmt .Errorf ("retry canceled during backoff: %w" , err )
186183 }
187184 }
188185 return resp , err
0 commit comments