19
19
import random
20
20
import time
21
21
import typing
22
- from typing import Optional , Callable , Generator , List , Type
22
+ from google .rpc import error_details_pb2
23
+ from grpc_status import rpc_status
24
+ from typing import Optional , Callable , Generator , List , Type , cast
23
25
from types import TracebackType
24
26
from pyspark .sql .connect .logging import logger
25
27
from pyspark .errors import PySparkRuntimeError , RetriesExceeded
@@ -45,6 +47,34 @@ class RetryPolicy:
45
47
Describes key aspects of RetryPolicy.
46
48
47
49
It's advised that different policies are implemented as different subclasses.
50
+
51
+ Parameters
52
+ ----------
53
+ max_retries: int, optional
54
+ Maximum number of retries.
55
+ initial_backoff: int
56
+ Start value of the exponential backoff.
57
+ max_backoff: int, optional
58
+ Maximum value of the exponential backoff, in milliseconds.
59
+ backoff_multiplier: float
60
+ Multiplicative base of the exponential backoff.
61
+ jitter: int
62
+ Sample a random value uniformly from the range [0, jitter] and add it to the backoff.
63
+ min_jitter_threshold: int
64
+ Minimal value of the backoff to add random jitter.
65
+ recognize_server_retry_delay: bool
66
+ Per gRPC standard, the server can send error messages that contain a `RetryInfo` message
67
+ with a `retry_delay` field indicating that the client should wait for at least `retry_delay`
68
+ amount of time before retrying again, see:
69
+ https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91
70
+
71
+ If this flag is set to true, RetryPolicy will use `RetryInfo.retry_delay` field
72
+ in the backoff computation. Server's `retry_delay` can override client's `max_backoff`.
73
+
74
+ This flag does not change which errors are retried, only how the backoff is computed.
75
+ `DefaultPolicy` additionally has a rule for retrying any error that contains `RetryInfo`.
76
+ max_server_retry_delay: int, optional
77
+ Limit for the server-provided `retry_delay`, in milliseconds.
48
78
"""
49
79
50
80
def __init__ (
@@ -55,13 +85,17 @@ def __init__(
55
85
backoff_multiplier : float = 1.0 ,
56
86
jitter : int = 0 ,
57
87
min_jitter_threshold : int = 0 ,
88
+ recognize_server_retry_delay : bool = False ,
89
+ max_server_retry_delay : Optional [int ] = None ,
58
90
):
59
91
self .max_retries = max_retries
60
92
self .initial_backoff = initial_backoff
61
93
self .max_backoff = max_backoff
62
94
self .backoff_multiplier = backoff_multiplier
63
95
self .jitter = jitter
64
96
self .min_jitter_threshold = min_jitter_threshold
97
+ self .recognize_server_retry_delay = recognize_server_retry_delay
98
+ self .max_server_retry_delay = max_server_retry_delay
65
99
self ._name = self .__class__ .__name__
66
100
67
101
@property
@@ -98,7 +132,7 @@ def name(self) -> str:
98
132
def can_retry (self , exception : BaseException ) -> bool :
99
133
return self .policy .can_retry (exception )
100
134
101
- def next_attempt (self ) -> Optional [int ]:
135
+ def next_attempt (self , exception : Optional [ BaseException ] = None ) -> Optional [int ]:
102
136
"""
103
137
Returns
104
138
-------
@@ -119,6 +153,14 @@ def next_attempt(self) -> Optional[int]:
119
153
float (self .policy .max_backoff ), wait_time * self .policy .backoff_multiplier
120
154
)
121
155
156
+ if exception is not None and self .policy .recognize_server_retry_delay :
157
+ retry_delay = extract_retry_delay (exception )
158
+ if retry_delay is not None :
159
+ logger .debug (f"The server has sent a retry delay of { retry_delay } ms." )
160
+ if self .policy .max_server_retry_delay is not None :
161
+ retry_delay = min (retry_delay , self .policy .max_server_retry_delay )
162
+ wait_time = max (wait_time , retry_delay )
163
+
122
164
# Jitter current backoff, after the future backoff was computed
123
165
if wait_time >= self .policy .min_jitter_threshold :
124
166
wait_time += random .uniform (0 , self .policy .jitter )
@@ -160,6 +202,7 @@ class Retrying:
160
202
This class is a point of entry into the retry logic.
161
203
The class accepts a list of retry policies and applies them in given order.
162
204
The first policy accepting an exception will be used.
205
+ If the error was matched by one policy, the other policies will be skipped.
163
206
164
207
The usage of the class should be as follows:
165
208
for attempt in Retrying(...):
@@ -217,17 +260,18 @@ def _wait(self) -> None:
217
260
return
218
261
219
262
# Attempt to find a policy to wait with
263
+ matched_policy = None
220
264
for policy in self ._policies :
221
- if not policy .can_retry (exception ):
222
- continue
223
-
224
- wait_time = policy .next_attempt ()
265
+ if policy .can_retry (exception ):
266
+ matched_policy = policy
267
+ break
268
+ if matched_policy is not None :
269
+ wait_time = matched_policy .next_attempt (exception )
225
270
if wait_time is not None :
226
271
logger .debug (
227
272
f"Got error: { repr (exception )} . "
228
- + f"Will retry after { wait_time } ms (policy: { policy .name } )"
273
+ + f"Will retry after { wait_time } ms (policy: { matched_policy .name } )"
229
274
)
230
-
231
275
self ._sleep (wait_time / 1000 )
232
276
return
233
277
@@ -274,6 +318,8 @@ def __init__(
274
318
max_backoff : Optional [int ] = 60000 ,
275
319
jitter : int = 500 ,
276
320
min_jitter_threshold : int = 2000 ,
321
+ recognize_server_retry_delay : bool = True ,
322
+ max_server_retry_delay : Optional [int ] = 10 * 60 * 1000 , # 10 minutes
277
323
):
278
324
super ().__init__ (
279
325
max_retries = max_retries ,
@@ -282,6 +328,8 @@ def __init__(
282
328
max_backoff = max_backoff ,
283
329
jitter = jitter ,
284
330
min_jitter_threshold = min_jitter_threshold ,
331
+ recognize_server_retry_delay = recognize_server_retry_delay ,
332
+ max_server_retry_delay = max_server_retry_delay ,
285
333
)
286
334
287
335
def can_retry (self , e : BaseException ) -> bool :
@@ -314,4 +362,29 @@ def can_retry(self, e: BaseException) -> bool:
314
362
if e .code () == grpc .StatusCode .UNAVAILABLE :
315
363
return True
316
364
365
+ if extract_retry_info (e ) is not None :
366
+ # All error messages containing `RetryInfo` should be retried.
367
+ return True
368
+
317
369
return False
370
+
371
+
372
def extract_retry_info(exception: BaseException) -> Optional[error_details_pb2.RetryInfo]:
    """
    Extract the `RetryInfo` detail attached to a gRPC error, if there is one.

    Parameters
    ----------
    exception : BaseException
        The error raised by the call. Only instances of `grpc.RpcError` are inspected.

    Returns
    -------
    error_details_pb2.RetryInfo or None
        The first `RetryInfo` message found in the status details, or None when the
        exception is not a `grpc.RpcError`, carries no rich status, carries no
        `RetryInfo` detail, or carries a status that cannot be read.
    """
    if not isinstance(exception, grpc.RpcError):
        return None

    try:
        status = rpc_status.from_call(cast(grpc.Call, exception))
    except ValueError as e:
        # from_call raises ValueError when the rich status disagrees with the call's own
        # code/details. This function is used while deciding whether an error is
        # retryable, so a malformed status must not replace the original error:
        # treat it the same as "no RetryInfo present".
        logger.debug(f"Could not read status details from the error: {repr(e)}")
        return None

    if not status:
        return None

    for detail in status.details:
        if detail.Is(error_details_pb2.RetryInfo.DESCRIPTOR):
            info = error_details_pb2.RetryInfo()
            detail.Unpack(info)
            return info
    return None
383
+
384
+
385
def extract_retry_delay(exception: BaseException) -> Optional[int]:
    """
    Return the server-provided retry delay carried by a gRPC error.

    Returns
    -------
    int or None
        `RetryInfo.retry_delay` converted to milliseconds, or None when
        `extract_retry_info` finds no `RetryInfo` for the exception.
    """
    info = extract_retry_info(exception)
    return None if info is None else info.retry_delay.ToMilliseconds()
0 commit comments