diff --git a/python/pyspark/sql/connect/client/retries.py b/python/pyspark/sql/connect/client/retries.py index e27100133b5ae..436da250d791e 100644 --- a/python/pyspark/sql/connect/client/retries.py +++ b/python/pyspark/sql/connect/client/retries.py @@ -19,7 +19,9 @@ import random import time import typing -from typing import Optional, Callable, Generator, List, Type +from google.rpc import error_details_pb2 +from grpc_status import rpc_status +from typing import Optional, Callable, Generator, List, Type, cast from types import TracebackType from pyspark.sql.connect.logging import logger from pyspark.errors import PySparkRuntimeError, RetriesExceeded @@ -45,6 +47,34 @@ class RetryPolicy: Describes key aspects of RetryPolicy. It's advised that different policies are implemented as different subclasses. + + Parameters + ---------- + max_retries: int, optional + Maximum number of retries. + initial_backoff: int + Start value of the exponential backoff. + max_backoff: int, optional + Maximal value of the exponential backoff. + backoff_multiplier: float + Multiplicative base of the exponential backoff. + jitter: int + Sample a random value uniformly from the range [0, jitter] and add it to the backoff. + min_jitter_threshold: int + Minimal value of the backoff to add random jitter. + recognize_server_retry_delay: bool + Per gRPC standard, the server can send error messages that contain `RetryInfo` message + with `retry_delay` field indicating that the client should wait for at least `retry_delay` + amount of time before retrying again, see: + https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91 + + If this flag is set to true, RetryPolicy will use `RetryInfo.retry_delay` field + in the backoff computation. Server's `retry_delay` can override client's `max_backoff`. + + This flag does not change which errors are retried, only how the backoff is computed. + `DefaultPolicy` additionally has a rule for retrying any error that contains `RetryInfo`. + max_server_retry_delay: int, optional + Limit for the server-provided `retry_delay`. 
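For intuition, here is a simplified, self-contained sketch of how these parameters combine into one wait time when `recognize_server_retry_delay` is enabled. This is not the actual `next_attempt` implementation, and the default values shown are illustrative only:

```python
# Simplified model of the backoff rule described above (illustrative only).
import random
from typing import Optional

def next_wait_ms(
    attempt: int,                              # 0-based retry attempt
    initial_backoff: int = 50,                 # illustrative values, not asserted defaults
    max_backoff: int = 60_000,
    backoff_multiplier: float = 4.0,
    jitter: int = 500,
    min_jitter_threshold: int = 2_000,
    server_retry_delay: Optional[int] = None,  # RetryInfo.retry_delay, in ms
    max_server_retry_delay: int = 10 * 60 * 1000,
) -> float:
    # Exponential backoff, capped by the client's max_backoff.
    wait = min(max_backoff, initial_backoff * backoff_multiplier**attempt)
    # The server-provided delay (capped separately by max_server_retry_delay)
    # may raise the wait above max_backoff, but never lowers the client's backoff.
    if server_retry_delay is not None:
        wait = max(wait, min(server_retry_delay, max_server_retry_delay))
    # Jitter is only added once the backoff is large enough.
    if wait >= min_jitter_threshold:
        wait += random.uniform(0, jitter)
    return wait

print(next_wait_ms(3))                              # plain exponential backoff (~3200 ms)
print(next_wait_ms(0, server_retry_delay=300_000))  # 5 min server delay overrides max_backoff
```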
""" def __init__( @@ -55,6 +85,8 @@ def __init__( backoff_multiplier: float = 1.0, jitter: int = 0, min_jitter_threshold: int = 0, + recognize_server_retry_delay: bool = False, + max_server_retry_delay: Optional[int] = None, ): self.max_retries = max_retries self.initial_backoff = initial_backoff @@ -62,6 +94,8 @@ def __init__( self.backoff_multiplier = backoff_multiplier self.jitter = jitter self.min_jitter_threshold = min_jitter_threshold + self.recognize_server_retry_delay = recognize_server_retry_delay + self.max_server_retry_delay = max_server_retry_delay self._name = self.__class__.__name__ @property @@ -98,7 +132,7 @@ def name(self) -> str: def can_retry(self, exception: BaseException) -> bool: return self.policy.can_retry(exception) - def next_attempt(self) -> Optional[int]: + def next_attempt(self, exception: Optional[BaseException] = None) -> Optional[int]: """ Returns ------- @@ -119,6 +153,14 @@ def next_attempt(self) -> Optional[int]: float(self.policy.max_backoff), wait_time * self.policy.backoff_multiplier ) + if exception is not None and self.policy.recognize_server_retry_delay: + retry_delay = extract_retry_delay(exception) + if retry_delay is not None: + logger.debug(f"The server has sent a retry delay of {retry_delay} ms.") + if self.policy.max_server_retry_delay is not None: + retry_delay = min(retry_delay, self.policy.max_server_retry_delay) + wait_time = max(wait_time, retry_delay) + # Jitter current backoff, after the future backoff was computed if wait_time >= self.policy.min_jitter_threshold: wait_time += random.uniform(0, self.policy.jitter) @@ -160,6 +202,7 @@ class Retrying: This class is a point of entry into the retry logic. The class accepts a list of retry policies and applies them in given order. The first policy accepting an exception will be used. + If the error was matched by one policy, the other policies will be skipped. The usage of the class should be as follows: for attempt in Retrying(...): @@ -217,17 +260,18 @@ def _wait(self) -> None: return # Attempt to find a policy to wait with + matched_policy = None for policy in self._policies: - if not policy.can_retry(exception): - continue - - wait_time = policy.next_attempt() + if policy.can_retry(exception): + matched_policy = policy + break + if matched_policy is not None: + wait_time = matched_policy.next_attempt(exception) if wait_time is not None: logger.debug( f"Got error: {repr(exception)}. " - + f"Will retry after {wait_time} ms (policy: {policy.name})" + + f"Will retry after {wait_time} ms (policy: {matched_policy.name})" ) - self._sleep(wait_time / 1000) return @@ -274,6 +318,8 @@ def __init__( max_backoff: Optional[int] = 60000, jitter: int = 500, min_jitter_threshold: int = 2000, + recognize_server_retry_delay: bool = True, + max_server_retry_delay: Optional[int] = 10 * 60 * 1000, # 10 minutes ): super().__init__( max_retries=max_retries, @@ -282,6 +328,8 @@ def __init__( max_backoff=max_backoff, jitter=jitter, min_jitter_threshold=min_jitter_threshold, + recognize_server_retry_delay=recognize_server_retry_delay, + max_server_retry_delay=max_server_retry_delay, ) def can_retry(self, e: BaseException) -> bool: @@ -314,4 +362,29 @@ def can_retry(self, e: BaseException) -> bool: if e.code() == grpc.StatusCode.UNAVAILABLE: return True + if extract_retry_info(e) is not None: + # All errors messages containing `RetryInfo` should be retried. 
+ return True + return False + + +def extract_retry_info(exception: BaseException) -> Optional[error_details_pb2.RetryInfo]: + """Extract and return RetryInfo from the grpc.RpcError""" + if isinstance(exception, grpc.RpcError): + status = rpc_status.from_call(cast(grpc.Call, exception)) + if status: + for d in status.details: + if d.Is(error_details_pb2.RetryInfo.DESCRIPTOR): + info = error_details_pb2.RetryInfo() + d.Unpack(info) + return info + return None + + +def extract_retry_delay(exception: BaseException) -> Optional[int]: + """Extract and return RetryInfo.retry_delay in milliseconds from grpc.RpcError if present.""" + retry_info = extract_retry_info(exception) + if retry_info is not None: + return retry_info.retry_delay.ToMilliseconds() + return None diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 43094b0e7e02b..fb41c9781b97a 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -35,7 +35,7 @@ ) from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator from pyspark.sql.connect.session import SparkSession as RemoteSparkSession - from pyspark.errors import PySparkRuntimeError, RetriesExceeded + from pyspark.errors import PySparkRuntimeError import pyspark.sql.connect.proto as proto class TestPolicy(DefaultPolicy): @@ -227,35 +227,6 @@ def test_is_closed(self): client.close() self.assertTrue(client.is_closed) - def test_retry(self): - client = SparkConnectClient("sc://foo/;token=bar") - - total_sleep = 0 - - def sleep(t): - nonlocal total_sleep - total_sleep += t - - try: - for attempt in Retrying(client._retry_policies, sleep=sleep): - with attempt: - raise TestException("Retryable error", grpc.StatusCode.UNAVAILABLE) - except RetriesExceeded: - pass - - # tolerated at least 10 mins of fails - self.assertGreaterEqual(total_sleep, 600) - - def test_retry_client_unit(self): - client = SparkConnectClient("sc://foo/;token=bar") - - policyA = TestPolicy() - policyB = DefaultPolicy() - - client.set_retry_policies([policyA, policyB]) - - self.assertEqual(client.get_retry_policies(), [policyA, policyB]) - def test_channel_builder_with_session(self): dummy = str(uuid.uuid4()) chan = DefaultChannelBuilder(f"sc://foo/;session_id={dummy}") diff --git a/python/pyspark/sql/tests/connect/client/test_client_retries.py b/python/pyspark/sql/tests/connect/client/test_client_retries.py new file mode 100644 index 0000000000000..400442363b470 --- /dev/null +++ b/python/pyspark/sql/tests/connect/client/test_client_retries.py @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
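The `extract_retry_info` / `extract_retry_delay` helpers added above look for a `RetryInfo` message packed into the gRPC error's status details. A protobuf-only sketch of that round trip (no server involved; assumes `grpcio-status` and `googleapis-common-protos` are installed):

```python
import google.protobuf.any_pb2 as any_pb2
from google.rpc import error_details_pb2, status_pb2

# "Server side": attach RetryInfo(retry_delay = 5 s) to an error status.
retry_info = error_details_pb2.RetryInfo()
retry_info.retry_delay.FromMilliseconds(5000)
detail = any_pb2.Any()
detail.Pack(retry_info)
status = status_pb2.Status(code=14, message="UNAVAILABLE", details=[detail])

# "Client side": unpack the detail and convert retry_delay to milliseconds,
# which is what extract_retry_delay() does after rpc_status.from_call().
for d in status.details:
    if d.Is(error_details_pb2.RetryInfo.DESCRIPTOR):
        info = error_details_pb2.RetryInfo()
        d.Unpack(info)
        print(info.retry_delay.ToMilliseconds())  # 5000
```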
+# + +import unittest + +from pyspark.testing.connectutils import should_test_connect, connect_requirement_message + +if should_test_connect: + import grpc + import google.protobuf.any_pb2 as any_pb2 + import google.protobuf.duration_pb2 as duration_pb2 + from google.rpc import status_pb2 + from google.rpc import error_details_pb2 + from pyspark.sql.connect.client import SparkConnectClient + from pyspark.sql.connect.client.retries import ( + Retrying, + DefaultPolicy, + ) + from pyspark.errors import RetriesExceeded + from pyspark.sql.tests.connect.client.test_client import ( + TestPolicy, + TestException, + ) + + class SleepTimeTracker: + """Tracks sleep times in ms for testing purposes.""" + + def __init__(self): + self._times = [] + + def sleep(self, t: float): + self._times.append(int(1000 * t)) + + @property + def times(self): + return list(self._times) + + def create_test_exception_with_details( + msg: str, + code: grpc.StatusCode = grpc.StatusCode.INTERNAL, + retry_delay: int = 0, + ) -> TestException: + """Helper function for creating TestException with additional error details + like retry_delay. + """ + retry_delay_msg = duration_pb2.Duration() + retry_delay_msg.FromMilliseconds(retry_delay) + retry_info = error_details_pb2.RetryInfo() + retry_info.retry_delay.CopyFrom(retry_delay_msg) + + # Pack RetryInfo into an Any type + retry_info_any = any_pb2.Any() + retry_info_any.Pack(retry_info) + status = status_pb2.Status( + code=code.value[0], + message=msg, + details=[retry_info_any], + ) + return TestException(msg=msg, code=code, trailing_status=status) + + def get_client_policies_map(client: SparkConnectClient) -> dict: + return {type(policy): policy for policy in client.get_retry_policies()} + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class SparkConnectClientRetriesTestCase(unittest.TestCase): + def assertListsAlmostEqual(self, first, second, places=None, msg=None, delta=None): + self.assertEqual(len(first), len(second), msg) + for i in range(len(first)): + self.assertAlmostEqual(first[i], second[i], places, msg, delta) + + def test_retry(self): + client = SparkConnectClient("sc://foo/;token=bar") + + sleep_tracker = SleepTimeTracker() + try: + for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep): + with attempt: + raise TestException("Retryable error", grpc.StatusCode.UNAVAILABLE) + except RetriesExceeded: + pass + + # tolerated at least 10 mins of fails + self.assertGreaterEqual(sum(sleep_tracker.times), 600) + + def test_retry_client_unit(self): + client = SparkConnectClient("sc://foo/;token=bar") + + policyA = TestPolicy() + policyB = DefaultPolicy() + + client.set_retry_policies([policyA, policyB]) + + self.assertEqual(client.get_retry_policies(), [policyA, policyB]) + + def test_default_policy_retries_retry_info(self): + client = SparkConnectClient("sc://foo/;token=bar") + policy = get_client_policies_map(client).get(DefaultPolicy) + self.assertIsNotNone(policy) + + # retry delay = 0, error code not matched by any policy. + # Testing if errors with RetryInfo are being retried by the DefaultPolicy. 
+ retry_delay = 0 + sleep_tracker = SleepTimeTracker() + try: + for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep): + with attempt: + raise create_test_exception_with_details( + msg="Some error message", + code=grpc.StatusCode.UNIMPLEMENTED, + retry_delay=retry_delay, + ) + except RetriesExceeded: + pass + expected_times = [ + min(policy.max_backoff, policy.initial_backoff * policy.backoff_multiplier**i) + for i in range(policy.max_retries) + ] + self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter) + + def test_retry_delay_overrides_max_backoff(self): + client = SparkConnectClient("sc://foo/;token=bar") + policy = get_client_policies_map(client).get(DefaultPolicy) + self.assertIsNotNone(policy) + + # retry delay = 5 mins. + # Testing if retry_delay overrides max_backoff. + retry_delay = 5 * 60 * 1000 + sleep_tracker = SleepTimeTracker() + # assert that retry_delay is greater than max_backoff to make sure the test is valid + self.assertGreaterEqual(retry_delay, policy.max_backoff) + try: + for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep): + with attempt: + raise create_test_exception_with_details( + "Some error message", + grpc.StatusCode.UNAVAILABLE, + retry_delay, + ) + except RetriesExceeded: + pass + expected_times = [retry_delay] * policy.max_retries + self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter) + + def test_max_server_retry_delay(self): + client = SparkConnectClient("sc://foo/;token=bar") + policy = get_client_policies_map(client).get(DefaultPolicy) + self.assertIsNotNone(policy) + + # retry delay = 10 hours + # Testing if max_server_retry_delay limit works. + retry_delay = 10 * 60 * 60 * 1000 + sleep_tracker = SleepTimeTracker() + try: + for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep): + with attempt: + raise create_test_exception_with_details( + "Some error message", + grpc.StatusCode.UNAVAILABLE, + retry_delay, + ) + except RetriesExceeded: + pass + + expected_times = [policy.max_server_retry_delay] * policy.max_retries + self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter) + + def test_return_to_exponential_backoff(self): + client = SparkConnectClient("sc://foo/;token=bar") + policy = get_client_policies_map(client).get(DefaultPolicy) + self.assertIsNotNone(policy) + + # Start with retry_delay = 5 mins, then set it to zero. + # Test if backoff goes back to client's exponential strategy. 
+ initial_retry_delay = 5 * 60 * 1000 + sleep_tracker = SleepTimeTracker() + try: + for i, attempt in enumerate( + Retrying(client._retry_policies, sleep=sleep_tracker.sleep) + ): + if i < 2: + retry_delay = initial_retry_delay + elif i < 5: + retry_delay = 0 + else: + break + with attempt: + raise create_test_exception_with_details( + "Some error message", + grpc.StatusCode.UNAVAILABLE, + retry_delay, + ) + except RetriesExceeded: + pass + + expected_times = [initial_retry_delay] * 2 + [ + policy.initial_backoff * policy.backoff_multiplier**i for i in range(2, 5) + ] + self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter) + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.client.test_client_retries import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_connect_retry.py b/python/pyspark/sql/tests/connect/test_connect_retry.py index f51e062479284..61ab0dcea8621 100644 --- a/python/pyspark/sql/tests/connect/test_connect_retry.py +++ b/python/pyspark/sql/tests/connect/test_connect_retry.py @@ -162,8 +162,8 @@ def test_multiple_policies_exceed(self): with attempt: self.stub(10, grpc.StatusCode.INTERNAL) - self.assertEqual(self.call_wrap["attempts"], 7) - self.assertEqual(self.call_wrap["raised"], 7) + self.assertEqual(self.call_wrap["attempts"], 3) + self.assertEqual(self.call_wrap["raised"], 3) if __name__ == "__main__": diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala new file mode 100644 index 0000000000000..3408c15b73f0d --- /dev/null +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
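One detail worth calling out in `test_return_to_exponential_backoff` above: the exponent in `expected_times` starts at 2 because the policy's internal exponential state keeps advancing even while the server's `retry_delay` overrides the wait, so once the server stops sending a delay the backoff resumes from the current attempt, not from the start. A quick check of that arithmetic (the backoff constants here are illustrative, not asserted defaults):

```python
# Expected sleep sequence for: 2 errors carrying retry_delay = 5 min,
# followed by 3 errors without a server-provided delay.
initial_backoff, multiplier = 50, 4.0   # illustrative values
retry_delay_ms = 5 * 60 * 1000

expected = [retry_delay_ms] * 2 + [initial_backoff * multiplier**i for i in range(2, 5)]
print(expected)  # [300000, 300000, 800.0, 3200.0, 12800.0]
```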
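The updated expectation in `test_multiple_policies_exceed` (3 attempts instead of 7) follows directly from the new first-match rule: previously every matching policy contributed its retry budget (2 + 4 retries plus the initial attempt), whereas now only the first policy that accepts the error is consulted. A small sketch against the Python API (assumes a pyspark build with Spark Connect; `MyError` and both policy classes are made up for illustration):

```python
from pyspark.errors import RetriesExceeded
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying

class MyError(Exception):
    pass

class SmallPolicy(RetryPolicy):
    def __init__(self):
        super().__init__(max_retries=2, initial_backoff=1)

    def can_retry(self, exception: BaseException) -> bool:
        return isinstance(exception, MyError)

class BigPolicy(RetryPolicy):
    def __init__(self):
        super().__init__(max_retries=4, initial_backoff=1)

    def can_retry(self, exception: BaseException) -> bool:
        return isinstance(exception, MyError)

attempts = 0
try:
    # Both policies match the error, but only SmallPolicy (listed first) is used.
    for attempt in Retrying([SmallPolicy(), BigPolicy()], sleep=lambda t: None):
        with attempt:
            attempts += 1
            raise MyError("always failing")
except RetriesExceeded:
    pass

print(attempts)  # 3 (1 initial attempt + 2 retries), not 2 + 4 + 1 = 7
```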
+ */ +package org.apache.spark.sql.connect.client + +import scala.concurrent.duration.FiniteDuration + +import com.google.protobuf.{Any, Duration} +import com.google.rpc +import io.grpc.{Status, StatusRuntimeException} +import io.grpc.protobuf.StatusProto +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually + +import org.apache.spark.sql.connect.test.ConnectFunSuite + +class SparkConnectClientRetriesSuite + extends ConnectFunSuite + with BeforeAndAfterEach + with Eventually { + + private class DummyFn(e: => Throwable, numFails: Int = 3) { + var counter = 0 + def fn(): Int = { + if (counter < numFails) { + counter += 1 + throw e + } else { + 42 + } + } + } + + /** Tracks sleep times in milliseconds for testing purposes. */ + private class SleepTimeTracker { + private val data = scala.collection.mutable.ListBuffer[Long]() + def sleep(t: Long): Unit = data.append(t) + def times: List[Long] = data.toList + def totalSleep: Long = data.sum + } + + /** Helper function for creating a test exception with retry_delay */ + private def createTestExceptionWithDetails( + msg: String, + code: Status.Code = Status.Code.INTERNAL, + retryDelay: FiniteDuration = FiniteDuration(0, "s")): StatusRuntimeException = { + // In grpc-java, RetryDelay should be specified as seconds: Long + nanos: Int + val seconds = retryDelay.toSeconds + val nanos = (retryDelay - FiniteDuration(seconds, "s")).toNanos.toInt + val retryDelayMsg = Duration + .newBuilder() + .setSeconds(seconds) + .setNanos(nanos) + .build() + val retryInfo = rpc.RetryInfo + .newBuilder() + .setRetryDelay(retryDelayMsg) + .build() + val status = rpc.Status + .newBuilder() + .setMessage(msg) + .setCode(code.value()) + .addDetails(Any.pack(retryInfo)) + .build() + StatusProto.toStatusRuntimeException(status) + } + + /** helper function for comparing two sequences of sleep times */ + private def assertLongSequencesAlmostEqual( + first: Seq[Long], + second: Seq[Long], + delta: Long): Unit = { + assert(first.length == second.length, "Lists have different lengths.") + for ((a, b) <- first.zip(second)) { + assert(math.abs(a - b) <= delta, s"Elements $a and $b differ by more than $delta.") + } + } + + test("SPARK-44721: Retries run for a minimum period") { + // repeat test few times to avoid random flakes + for (_ <- 1 to 10) { + val st = new SleepTimeTracker() + val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE), numFails = 100) + val retryHandler = new GrpcRetryHandler(RetryPolicy.defaultPolicies(), st.sleep) + + assertThrows[RetriesExceeded] { + retryHandler.retry { + dummyFn.fn() + } + } + + assert(st.totalSleep >= 10 * 60 * 1000) // waited at least 10 minutes + } + } + + test("SPARK-44275: retry actually retries") { + val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE)) + val retryPolicies = RetryPolicy.defaultPolicies() + val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {}) + val result = retryHandler.retry { dummyFn.fn() } + + assert(result == 42) + assert(dummyFn.counter == 3) + } + + test("SPARK-44275: default retryException retries only on UNAVAILABLE") { + val dummyFn = new DummyFn(new StatusRuntimeException(Status.ABORTED)) + val retryPolicies = RetryPolicy.defaultPolicies() + val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {}) + + assertThrows[StatusRuntimeException] { + retryHandler.retry { dummyFn.fn() } + } + assert(dummyFn.counter == 1) + } + + test("SPARK-44275: retry uses canRetry to filter exceptions") { + val dummyFn = new 
DummyFn(new StatusRuntimeException(Status.UNAVAILABLE)) + val retryPolicy = RetryPolicy(canRetry = _ => false, name = "TestPolicy") + val retryHandler = new GrpcRetryHandler(retryPolicy) + + assertThrows[StatusRuntimeException] { + retryHandler.retry { dummyFn.fn() } + } + assert(dummyFn.counter == 1) + } + + test("SPARK-44275: retry does not exceed maxRetries") { + val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE)) + val retryPolicy = RetryPolicy(canRetry = _ => true, maxRetries = Some(1), name = "TestPolicy") + val retryHandler = new GrpcRetryHandler(retryPolicy, sleep = _ => {}) + + assertThrows[RetriesExceeded] { + retryHandler.retry { dummyFn.fn() } + } + assert(dummyFn.counter == 2) + } + + def testPolicySpecificError(maxRetries: Int, status: Status): RetryPolicy = { + RetryPolicy( + maxRetries = Some(maxRetries), + name = s"Policy for ${status.getCode}", + canRetry = { + case e: StatusRuntimeException => e.getStatus.getCode == status.getCode + case _ => false + }) + } + + test("Test multiple policies") { + val policy1 = testPolicySpecificError(maxRetries = 2, status = Status.UNAVAILABLE) + val policy2 = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL) + + // Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors + + val errors = (List.fill(2)(Status.UNAVAILABLE) ++ List.fill(4)(Status.INTERNAL)).iterator + + new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({ + val e = errors.nextOption() + if (e.isDefined) { + throw e.get.asRuntimeException() + } + }) + + assert(!errors.hasNext) + } + + test("Test multiple policies exceed") { + val policy1 = testPolicySpecificError(maxRetries = 2, status = Status.INTERNAL) + val policy2 = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL) + + val errors = List.fill(10)(Status.INTERNAL).iterator + var countAttempted = 0 + + assertThrows[RetriesExceeded]( + new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({ + countAttempted += 1 + val e = errors.nextOption() + if (e.isDefined) { + throw e.get.asRuntimeException() + } + })) + + assert(countAttempted == 3) + } + + test("DefaultPolicy retries exceptions with RetryInfo") { + // Error contains RetryInfo with retry_delay set to 0 + val dummyFn = + new DummyFn(createTestExceptionWithDetails(msg = "Some error message"), numFails = 100) + val retryPolicies = RetryPolicy.defaultPolicies() + val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {}) + assertThrows[RetriesExceeded] { + retryHandler.retry { dummyFn.fn() } + } + + // Should be retried by DefaultPolicy + val policy = retryPolicies.find(_.name == "DefaultPolicy").get + assert(dummyFn.counter == policy.maxRetries.get + 1) + } + + test("retry_delay overrides maxBackoff") { + val st = new SleepTimeTracker() + val retryDelay = FiniteDuration(5, "min") + val dummyFn = new DummyFn( + createTestExceptionWithDetails(msg = "Some error message", retryDelay = retryDelay), + numFails = 100) + val retryPolicies = RetryPolicy.defaultPolicies() + val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep) + + assertThrows[RetriesExceeded] { + retryHandler.retry { dummyFn.fn() } + } + + // Should be retried by DefaultPolicy + val policy = retryPolicies.find(_.name == "DefaultPolicy").get + // sleep times are higher than maxBackoff and are equal to retryDelay + jitter + st.times.foreach(t => assert(t > policy.maxBackoff.get.toMillis + policy.jitter.toMillis)) + val expectedSleeps = List.fill(policy.maxRetries.get)(retryDelay.toMillis) + 
assertLongSequencesAlmostEqual(st.times, expectedSleeps, policy.jitter.toMillis) + } + + test("maxServerRetryDelay limits retry_delay") { + val st = new SleepTimeTracker() + val retryDelay = FiniteDuration(5, "d") + val dummyFn = new DummyFn( + createTestExceptionWithDetails(msg = "Some error message", retryDelay = retryDelay), + numFails = 100) + val retryPolicies = RetryPolicy.defaultPolicies() + val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep) + + assertThrows[RetriesExceeded] { + retryHandler.retry { dummyFn.fn() } + } + + // Should be retried by DefaultPolicy + val policy = retryPolicies.find(_.name == "DefaultPolicy").get + val expectedSleeps = List.fill(policy.maxRetries.get)(policy.maxServerRetryDelay.get.toMillis) + assertLongSequencesAlmostEqual(st.times, expectedSleeps, policy.jitter.toMillis) + } + + test("Policy uses to exponential backoff after retry_delay is unset") { + val st = new SleepTimeTracker() + val retryDelay = FiniteDuration(5, "min") + val retryPolicies = RetryPolicy.defaultPolicies() + val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep) + val errors = ( + List.fill(2)( + createTestExceptionWithDetails( + msg = "Some error message", + retryDelay = retryDelay)) ++ List.fill(3)( + createTestExceptionWithDetails( + msg = "Some error message", + code = Status.Code.UNAVAILABLE)) + ).iterator + + retryHandler.retry({ + if (errors.hasNext) { + throw errors.next() + } + }) + assert(!errors.hasNext) + + // Should be retried by DefaultPolicy + val policy = retryPolicies.find(_.name == "DefaultPolicy").get + val expectedSleeps = List.fill(2)(retryDelay.toMillis) ++ List.tabulate(3)(i => + policy.initialBackoff.toMillis * math.pow(policy.backoffMultiplier, i + 2).toLong) + assertLongSequencesAlmostEqual(st.times, expectedSleeps, delta = policy.jitter.toMillis) + } +} diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 9bb8f5889d330..a41ea344cbd4c 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -339,130 +339,6 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { } } - private class DummyFn(e: => Throwable, numFails: Int = 3) { - var counter = 0 - def fn(): Int = { - if (counter < numFails) { - counter += 1 - throw e - } else { - 42 - } - } - } - - test("SPARK-44721: Retries run for a minimum period") { - // repeat test few times to avoid random flakes - for (_ <- 1 to 10) { - var totalSleepMs: Long = 0 - - def sleep(t: Long): Unit = { - totalSleepMs += t - } - - val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE), numFails = 100) - val retryHandler = new GrpcRetryHandler(RetryPolicy.defaultPolicies(), sleep) - - assertThrows[RetriesExceeded] { - retryHandler.retry { - dummyFn.fn() - } - } - - assert(totalSleepMs >= 10 * 60 * 1000) // waited at least 10 minutes - } - } - - test("SPARK-44275: retry actually retries") { - val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE)) - val retryPolicies = RetryPolicy.defaultPolicies() - val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {}) - val result = retryHandler.retry { dummyFn.fn() } - - assert(result == 42) - assert(dummyFn.counter == 3) - } - - 
test("SPARK-44275: default retryException retries only on UNAVAILABLE") { - val dummyFn = new DummyFn(new StatusRuntimeException(Status.ABORTED)) - val retryPolicies = RetryPolicy.defaultPolicies() - val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {}) - - assertThrows[StatusRuntimeException] { - retryHandler.retry { dummyFn.fn() } - } - assert(dummyFn.counter == 1) - } - - test("SPARK-44275: retry uses canRetry to filter exceptions") { - val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE)) - val retryPolicy = RetryPolicy(canRetry = _ => false, name = "TestPolicy") - val retryHandler = new GrpcRetryHandler(retryPolicy) - - assertThrows[StatusRuntimeException] { - retryHandler.retry { dummyFn.fn() } - } - assert(dummyFn.counter == 1) - } - - test("SPARK-44275: retry does not exceed maxRetries") { - val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE)) - val retryPolicy = RetryPolicy(canRetry = _ => true, maxRetries = Some(1), name = "TestPolicy") - val retryHandler = new GrpcRetryHandler(retryPolicy, sleep = _ => {}) - - assertThrows[RetriesExceeded] { - retryHandler.retry { dummyFn.fn() } - } - assert(dummyFn.counter == 2) - } - - def testPolicySpecificError(maxRetries: Int, status: Status): RetryPolicy = { - RetryPolicy( - maxRetries = Some(maxRetries), - name = s"Policy for ${status.getCode}", - canRetry = { - case e: StatusRuntimeException => e.getStatus.getCode == status.getCode - case _ => false - }) - } - - test("Test multiple policies") { - val policy1 = testPolicySpecificError(maxRetries = 2, status = Status.UNAVAILABLE) - val policy2 = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL) - - // Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors - - val errors = (List.fill(2)(Status.UNAVAILABLE) ++ List.fill(4)(Status.INTERNAL)).iterator - - new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({ - val e = errors.nextOption() - if (e.isDefined) { - throw e.get.asRuntimeException() - } - }) - - assert(!errors.hasNext) - } - - test("Test multiple policies exceed") { - val policy1 = testPolicySpecificError(maxRetries = 2, status = Status.INTERNAL) - val policy2 = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL) - - val errors = List.fill(10)(Status.INTERNAL).iterator - var countAttempted = 0 - - assertThrows[RetriesExceeded]( - new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({ - countAttempted += 1 - val e = errors.nextOption() - if (e.isDefined) { - throw e.get.asRuntimeException() - } - })) - - assert(countAttempted == 7) - } - test("ArtifactManager retries errors") { var attempt = 0 diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala index 7e0a356b9e493..0a38d18773deb 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala @@ -194,15 +194,17 @@ private[sql] object GrpcRetryHandler extends Logging { return } - for (policy <- policies if policy.canRetry(lastException)) { - val time = policy.nextAttempt() - + // find a policy to wait with + val matchedPolicyOpt = policies.find(_.canRetry(lastException)) + if (matchedPolicyOpt.isDefined) { + val matchedPolicy = matchedPolicyOpt.get + val time = matchedPolicy.nextAttempt(lastException) if (time.isDefined) { logWarning( 
log"Non-Fatal error during RPC execution: ${MDC(ERROR, lastException)}, " + log"retrying (wait=${MDC(RETRY_WAIT_TIME, time.get.toMillis)} ms, " + log"currentRetryNum=${MDC(NUM_RETRY, currentRetryNum)}, " + - log"policy=${MDC(POLICY, policy.getName)}).") + log"policy=${MDC(POLICY, matchedPolicy.getName)}).") sleep(time.get.toMillis) return } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala index 8c8472d780dbc..5b5c4b517923e 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala @@ -18,9 +18,14 @@ package org.apache.spark.sql.connect.client import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.jdk.CollectionConverters._ import scala.util.Random +import com.google.rpc.RetryInfo import io.grpc.{Status, StatusRuntimeException} +import io.grpc.protobuf.StatusProto + +import org.apache.spark.internal.Logging /** * [[RetryPolicy]] configure the retry mechanism in [[GrpcRetryHandler]] @@ -33,8 +38,27 @@ import io.grpc.{Status, StatusRuntimeException} * Maximal value of the exponential backoff (ms). * @param backoffMultiplier * Multiplicative base of the exponential backoff. + * @param jitter + * Sample a random value uniformly from the range [0, jitter] and add it to the backoff. + * @param minJitterThreshold + * Minimal value of the backoff to add random jitter. * @param canRetry * Function that determines whether a retry is to be performed in the event of an error. + * @param name + * Name of the policy. + * @param recognizeServerRetryDelay + * Per gRPC standard, the server can send error messages that contain `RetryInfo` message with + * `retry_delay` field indicating that the client should wait for at least `retry_delay` amount + * of time before retrying again, see: + * https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91 + * + * If this flag is set to true, RetryPolicy will use `RetryInfo.retry_delay` field in the backoff + * computation. Server's `retry_delay` can override client's `maxBackoff`. + * + * This flag does not change which errors are retried, only how the backoff is computed. + * `DefaultPolicy` additionally has a rule for retrying any error that contains `RetryInfo`. + * @param maxServerRetryDelay + * Limit for the server-provided `retry_delay`. 
*/ case class RetryPolicy( maxRetries: Option[Int] = None, @@ -44,14 +68,16 @@ case class RetryPolicy( jitter: FiniteDuration = FiniteDuration(0, "s"), minJitterThreshold: FiniteDuration = FiniteDuration(0, "s"), canRetry: Throwable => Boolean, - name: String) { + name: String, + recognizeServerRetryDelay: Boolean = false, + maxServerRetryDelay: Option[FiniteDuration] = None) { def getName: String = name def toState: RetryPolicy.RetryPolicyState = new RetryPolicy.RetryPolicyState(this) } -object RetryPolicy { +object RetryPolicy extends Logging { def defaultPolicy(): RetryPolicy = RetryPolicy( name = "DefaultPolicy", // Please synchronize changes here with Python side: @@ -65,7 +91,9 @@ object RetryPolicy { backoffMultiplier = 4.0, jitter = FiniteDuration(500, "ms"), minJitterThreshold = FiniteDuration(2, "s"), - canRetry = defaultPolicyRetryException) + canRetry = defaultPolicyRetryException, + recognizeServerRetryDelay = true, + maxServerRetryDelay = Some(FiniteDuration(10, "min"))) // list of policies to be used by this client def defaultPolicies(): Seq[RetryPolicy] = List(defaultPolicy()) @@ -77,7 +105,7 @@ object RetryPolicy { private var nextWait: Duration = policy.initialBackoff // return waiting time until next attempt, or None if has exceeded max retries - def nextAttempt(): Option[Duration] = { + def nextAttempt(e: Throwable): Option[Duration] = { if (policy.maxRetries.isDefined && numberAttempts >= policy.maxRetries.get) { return None } @@ -90,6 +118,14 @@ object RetryPolicy { nextWait = nextWait min policy.maxBackoff.get } + if (policy.recognizeServerRetryDelay) { + extractRetryDelay(e).foreach { retryDelay => + logDebug(s"The server has sent a retry delay of $retryDelay ms.") + val retryDelayLimited = retryDelay min policy.maxServerRetryDelay.getOrElse(retryDelay) + currentWait = currentWait max retryDelayLimited + } + } + if (currentWait >= policy.minJitterThreshold) { currentWait += Random.nextDouble() * policy.jitter } @@ -127,8 +163,33 @@ object RetryPolicy { if (statusCode == Status.Code.UNAVAILABLE) { return true } + + // All errors messages containing `RetryInfo` should be retried. + if (extractRetryInfo(e).isDefined) { + return true + } + false case _ => false } } + + private def extractRetryInfo(e: Throwable): Option[RetryInfo] = { + e match { + case e: StatusRuntimeException => + Option(StatusProto.fromThrowable(e)) + .flatMap(status => + status.getDetailsList.asScala + .find(_.is(classOf[RetryInfo])) + .map(_.unpack(classOf[RetryInfo]))) + case _ => None + } + } + + private def extractRetryDelay(e: Throwable): Option[FiniteDuration] = { + extractRetryInfo(e) + .flatMap(retryInfo => Option(retryInfo.getRetryDelay)) + .map(retryDelay => + FiniteDuration(retryDelay.getSeconds, "s") + FiniteDuration(retryDelay.getNanos, "ns")) + } }
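Finally, on the Python side the new knobs are plain constructor arguments, so a client can install a tuned `DefaultPolicy` the same way the relocated `test_retry_client_unit` does. A usage sketch (no connection to a real server is made; the 2-minute cap is just an example value):

```python
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.client.retries import DefaultPolicy

client = SparkConnectClient("sc://foo/;token=bar")

# Honor server-provided retry delays, but never wait more than 2 minutes on them.
policy = DefaultPolicy(
    recognize_server_retry_delay=True,
    max_server_retry_delay=2 * 60 * 1000,
)
client.set_retry_policies([policy])
assert client.get_retry_policies() == [policy]
```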