Skip to content

[SPARK-52673][CONNECT][CLIENT] Add grpc RetryInfo handling to Spark Connect retry policies #51363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 81 additions & 8 deletions python/pyspark/sql/connect/client/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import random
import time
import typing
from typing import Optional, Callable, Generator, List, Type
from google.rpc import error_details_pb2
from grpc_status import rpc_status
from typing import Optional, Callable, Generator, List, Type, cast
from types import TracebackType
from pyspark.sql.connect.logging import logger
from pyspark.errors import PySparkRuntimeError, RetriesExceeded
Expand All @@ -45,6 +47,34 @@ class RetryPolicy:
Describes key aspects of RetryPolicy.

It's advised that different policies are implemented as different subclasses.

Parameters
----------
max_retries: int, optional
Maximum number of retries.
initial_backoff: int
Start value of the exponential backoff.
max_backoff: int, optional
Maximal value of the exponential backoff.
backoff_multiplier: float
Multiplicative base of the exponential backoff.
jitter: int
Sample a random value uniformly from the range [0, jitter] and add it to the backoff.
min_jitter_threshold: int
Minimal value of the backoff to add random jitter.
recognize_server_retry_delay: bool
Per gRPC standard, the server can send error messages that contain `RetryInfo` message
with `retry_delay` field indicating that the client should wait for at least `retry_delay`
amount of time before retrying again, see:
https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91

If this flag is set to true, RetryPolicy will use `RetryInfo.retry_delay` field
in the backoff computation. Server's `retry_delay` can override client's `max_backoff`.

This flag does not change which errors are retried, only how the backoff is computed.
`DefaultPolicy` additionally has a rule for retrying any error that contains `RetryInfo`.
max_server_retry_delay: int, optional
Limit for the server-provided `retry_delay`.
"""

def __init__(
Expand All @@ -55,13 +85,17 @@ def __init__(
backoff_multiplier: float = 1.0,
jitter: int = 0,
min_jitter_threshold: int = 0,
recognize_server_retry_delay: bool = False,
max_server_retry_delay: Optional[int] = None,
):
self.max_retries = max_retries
self.initial_backoff = initial_backoff
self.max_backoff = max_backoff
self.backoff_multiplier = backoff_multiplier
self.jitter = jitter
self.min_jitter_threshold = min_jitter_threshold
self.recognize_server_retry_delay = recognize_server_retry_delay
self.max_server_retry_delay = max_server_retry_delay
self._name = self.__class__.__name__

@property
Expand Down Expand Up @@ -98,7 +132,7 @@ def name(self) -> str:
def can_retry(self, exception: BaseException) -> bool:
return self.policy.can_retry(exception)

def next_attempt(self) -> Optional[int]:
def next_attempt(self, exception: Optional[BaseException] = None) -> Optional[int]:
"""
Returns
-------
Expand All @@ -119,6 +153,14 @@ def next_attempt(self) -> Optional[int]:
float(self.policy.max_backoff), wait_time * self.policy.backoff_multiplier
)

if exception is not None and self.policy.recognize_server_retry_delay:
retry_delay = extract_retry_delay(exception)
if retry_delay is not None:
logger.debug(f"The server has sent a retry delay of {retry_delay} ms.")
if self.policy.max_server_retry_delay is not None:
retry_delay = min(retry_delay, self.policy.max_server_retry_delay)
wait_time = max(wait_time, retry_delay)

# Jitter current backoff, after the future backoff was computed
if wait_time >= self.policy.min_jitter_threshold:
wait_time += random.uniform(0, self.policy.jitter)
Expand Down Expand Up @@ -160,6 +202,7 @@ class Retrying:
This class is a point of entry into the retry logic.
The class accepts a list of retry policies and applies them in given order.
The first policy accepting an exception will be used.
If the error was matched by one policy, the other policies will be skipped.

The usage of the class should be as follows:
for attempt in Retrying(...):
Expand Down Expand Up @@ -217,17 +260,18 @@ def _wait(self) -> None:
return

# Attempt to find a policy to wait with
matched_policy = None
for policy in self._policies:
if not policy.can_retry(exception):
continue

wait_time = policy.next_attempt()
if policy.can_retry(exception):
matched_policy = policy
break
if matched_policy is not None:
wait_time = matched_policy.next_attempt(exception)
if wait_time is not None:
logger.debug(
f"Got error: {repr(exception)}. "
+ f"Will retry after {wait_time} ms (policy: {policy.name})"
+ f"Will retry after {wait_time} ms (policy: {matched_policy.name})"
)

self._sleep(wait_time / 1000)
return

Expand Down Expand Up @@ -274,6 +318,8 @@ def __init__(
max_backoff: Optional[int] = 60000,
jitter: int = 500,
min_jitter_threshold: int = 2000,
recognize_server_retry_delay: bool = True,
max_server_retry_delay: Optional[int] = 10 * 60 * 1000, # 10 minutes
):
super().__init__(
max_retries=max_retries,
Expand All @@ -282,6 +328,8 @@ def __init__(
max_backoff=max_backoff,
jitter=jitter,
min_jitter_threshold=min_jitter_threshold,
recognize_server_retry_delay=recognize_server_retry_delay,
max_server_retry_delay=max_server_retry_delay,
)

def can_retry(self, e: BaseException) -> bool:
Expand Down Expand Up @@ -314,4 +362,29 @@ def can_retry(self, e: BaseException) -> bool:
if e.code() == grpc.StatusCode.UNAVAILABLE:
return True

if extract_retry_info(e) is not None:
# All errors messages containing `RetryInfo` should be retried.
return True

return False


def extract_retry_info(exception: BaseException) -> Optional[error_details_pb2.RetryInfo]:
    """Return the ``RetryInfo`` detail attached to a ``grpc.RpcError``, if any.

    Parameters
    ----------
    exception : BaseException
        The exception raised by a gRPC call. Non-RPC exceptions yield ``None``.

    Returns
    -------
    Optional[error_details_pb2.RetryInfo]
        The unpacked ``RetryInfo`` message when present in the error's status
        details, otherwise ``None``.
    """
    # Only gRPC errors can carry rich status details.
    if not isinstance(exception, grpc.RpcError):
        return None

    # Recover the google.rpc.Status proto embedded in the call's trailing metadata.
    status = rpc_status.from_call(cast(grpc.Call, exception))
    if not status:
        return None

    # Scan the packed Any details for a RetryInfo payload.
    for detail in status.details:
        if detail.Is(error_details_pb2.RetryInfo.DESCRIPTOR):
            retry_info = error_details_pb2.RetryInfo()
            detail.Unpack(retry_info)
            return retry_info
    return None


def extract_retry_delay(exception: BaseException) -> Optional[int]:
    """Return ``RetryInfo.retry_delay`` from a ``grpc.RpcError`` in milliseconds.

    Parameters
    ----------
    exception : BaseException
        The exception raised by a gRPC call.

    Returns
    -------
    Optional[int]
        The server-provided retry delay converted to milliseconds, or ``None``
        when the exception carries no ``RetryInfo`` detail.
    """
    info = extract_retry_info(exception)
    # retry_delay is a protobuf Duration; convert it to an integral millisecond count.
    return info.retry_delay.ToMilliseconds() if info is not None else None
31 changes: 1 addition & 30 deletions python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
from pyspark.errors import PySparkRuntimeError, RetriesExceeded
from pyspark.errors import PySparkRuntimeError
import pyspark.sql.connect.proto as proto

class TestPolicy(DefaultPolicy):
Expand Down Expand Up @@ -227,35 +227,6 @@ def test_is_closed(self):
client.close()
self.assertTrue(client.is_closed)

def test_retry(self):
client = SparkConnectClient("sc://foo/;token=bar")

total_sleep = 0

def sleep(t):
nonlocal total_sleep
total_sleep += t

try:
for attempt in Retrying(client._retry_policies, sleep=sleep):
with attempt:
raise TestException("Retryable error", grpc.StatusCode.UNAVAILABLE)
except RetriesExceeded:
pass

# tolerated at least 10 mins of fails
self.assertGreaterEqual(total_sleep, 600)

def test_retry_client_unit(self):
client = SparkConnectClient("sc://foo/;token=bar")

policyA = TestPolicy()
policyB = DefaultPolicy()

client.set_retry_policies([policyA, policyB])

self.assertEqual(client.get_retry_policies(), [policyA, policyB])

def test_channel_builder_with_session(self):
dummy = str(uuid.uuid4())
chan = DefaultChannelBuilder(f"sc://foo/;session_id={dummy}")
Expand Down
Loading