diff --git a/setup.py b/setup.py index 131c8f81..9210d545 100644 --- a/setup.py +++ b/setup.py @@ -29,4 +29,7 @@ "enum-compat>=0.0.1", ), options={"bdist_wheel": {"universal": True}}, + extras_require={ + "yc": ["yandexcloud", ], + } ) diff --git a/ydb/_utilities.py b/ydb/_utilities.py index 765a694c..32419b1b 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -13,6 +13,10 @@ interceptor = None +_grpcs_protocol = "grpcs://" +_grpc_protocol = "grpc://" + + def wrap_result_in_future(result): f = futures.Future() f.set_result(result) @@ -33,6 +37,32 @@ def x_ydb_sdk_build_info_header(): return ("x-ydb-sdk-build-info", "ydb-python-sdk/" + ydb_version.VERSION) +def is_secure_protocol(endpoint): + return endpoint.startswith("grpcs://") + + +def wrap_endpoint(endpoint): + if endpoint.startswith(_grpcs_protocol): + return endpoint[len(_grpcs_protocol) :] + if endpoint.startswith(_grpc_protocol): + return endpoint[len(_grpc_protocol) :] + return endpoint + + +def parse_connection_string(connection_string): + cs = connection_string + if not cs.startswith(_grpc_protocol) and not cs.startswith(_grpcs_protocol): + # default is grpcs + cs = _grpcs_protocol + cs + + p = six.moves.urllib.parse.urlparse(connection_string) + b = six.moves.urllib.parse.parse_qs(p.query) + database = b.get("database", []) + assert len(database) > 0 + + return p.scheme + "://" + p.netloc, database[0] + + # Decorator that ensures no exceptions are leaked from decorated async call def wrap_async_call_exceptions(f): @functools.wraps(f) diff --git a/ydb/aio/iam.py b/ydb/aio/iam.py index bf06340d..a2da48b7 100644 --- a/ydb/aio/iam.py +++ b/ydb/aio/iam.py @@ -2,11 +2,10 @@ import time import abc -import asyncio import logging import six -from ydb import issues, credentials from ydb.iam import auth +from .credentials import AbstractExpiringTokenCredentials logger = logging.getLogger(__name__) @@ -25,127 +24,20 @@ aiohttp = None -class _OneToManyValue(object): - def __init__(self): - self._value = None - self._condition = asyncio.Condition() - - async def consume(self, timeout=3): - async with self._condition: - if self._value is None: - try: - await asyncio.wait_for(self._condition.wait(), timeout=timeout) - except Exception: - return self._value - return self._value - - async def update(self, n_value): - async with self._condition: - prev_value = self._value - self._value = n_value - if prev_value is None: - self._condition.notify_all() - - -class _AtMostOneExecution(object): - def __init__(self): - self._can_schedule = True - self._lock = asyncio.Lock() # Lock to guarantee only one execution - - async def _wrapped_execution(self, callback): - await self._lock.acquire() - try: - res = callback() - if asyncio.iscoroutine(res): - await res - except Exception: - pass - - finally: - self._lock.release() - self._can_schedule = True - - def submit(self, callback): - if self._can_schedule: - self._can_schedule = False - asyncio.ensure_future(self._wrapped_execution(callback)) - - @six.add_metaclass(abc.ABCMeta) -class IamTokenCredentials(auth.IamTokenCredentials): - def __init__(self): - super(IamTokenCredentials, self).__init__() - self._tp = _AtMostOneExecution() - self._iam_token = _OneToManyValue() - - @abc.abstractmethod - async def _get_iam_token(self): - pass - - async def _refresh(self): - current_time = time.time() - self._log_refresh_start(current_time) - - try: - auth_metadata = await self._get_iam_token() - await self._iam_token.update(auth_metadata["access_token"]) - self.update_expiration_info(auth_metadata) - self.logger.info( - "Token refresh successful. current_time %s, refresh_in %s", - current_time, - self._refresh_in, - ) - - except (KeyboardInterrupt, SystemExit): - return - - except Exception as e: - self.last_error = str(e) - await asyncio.sleep(1) - self._tp.submit(self._refresh) - - async def iam_token(self): - current_time = time.time() - if current_time > self._refresh_in: - self._tp.submit(self._refresh) - - iam_token = await self._iam_token.consume(timeout=3) - if iam_token is None: - if self.last_error is None: - raise issues.ConnectionError( - "%s: timeout occurred while waiting for token.\n%s" - % self.__class__.__name__, - self.extra_error_message, - ) - raise issues.ConnectionError( - "%s: %s.\n%s" - % (self.__class__.__name__, self.last_error, self.extra_error_message) - ) - return iam_token - - async def auth_metadata(self): - return [(credentials.YDB_AUTH_TICKET_HEADER, await self.iam_token())] - - -@six.add_metaclass(abc.ABCMeta) -class TokenServiceCredentials(IamTokenCredentials): +class TokenServiceCredentials(AbstractExpiringTokenCredentials): def __init__(self, iam_endpoint=None, iam_channel_credentials=None): super(TokenServiceCredentials, self).__init__() + assert ( + iam_token_service_pb2_grpc is not None + ), "run pip install==ydb[yc] to use service account credentials" + self._get_token_request_timeout = 10 self._iam_endpoint = ( "iam.api.cloud.yandex.net:443" if iam_endpoint is None else iam_endpoint ) self._iam_channel_credentials = ( {} if iam_channel_credentials is None else iam_channel_credentials ) - self._get_token_request_timeout = 10 - if ( - iam_token_service_pb2_grpc is None - or jwt is None - or iam_token_service_pb2 is None - ): - raise RuntimeError( - "Install jwt & yandex python cloud library to use service account credentials provider" - ) def _channel_factory(self): return grpc.aio.secure_channel( @@ -157,7 +49,7 @@ def _channel_factory(self): def _get_token_request(self): pass - async def _get_iam_token(self): + async def _make_token_request(self): async with self._channel_factory() as channel: stub = iam_token_service_pb2_grpc.IamTokenServiceStub(channel) response = await stub.Create( @@ -209,20 +101,19 @@ def _get_token_request(self): ) -class MetadataUrlCredentials(IamTokenCredentials): +class MetadataUrlCredentials(AbstractExpiringTokenCredentials): def __init__(self, metadata_url=None): super(MetadataUrlCredentials, self).__init__() - if aiohttp is None: - raise RuntimeError( - "Install aiohttp library to use metadata credentials provider" - ) + assert ( + aiohttp is not None + ), "Install aiohttp library to use metadata credentials provider" self._metadata_url = ( auth.DEFAULT_METADATA_URL if metadata_url is None else metadata_url ) self._tp.submit(self._refresh) self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud." - async def _get_iam_token(self): + async def _make_token_request(self): timeout = aiohttp.ClientTimeout(total=2) async with aiohttp.ClientSession(timeout=timeout) as session: async with session.get( diff --git a/ydb/connection.py b/ydb/connection.py index ed86de44..95db084a 100644 --- a/ydb/connection.py +++ b/ydb/connection.py @@ -138,7 +138,8 @@ def _construct_metadata(driver_config, settings): if driver_config.database is not None: metadata.append((YDB_DATABASE_HEADER, driver_config.database)) - if driver_config.credentials is not None: + need_rpc_auth = getattr(settings, "need_rpc_auth", True) + if driver_config.credentials is not None and need_rpc_auth: metadata.extend(driver_config.credentials.auth_metadata()) if settings is not None: diff --git a/ydb/credentials.py b/ydb/credentials.py index a291258d..96247ef3 100644 --- a/ydb/credentials.py +++ b/ydb/credentials.py @@ -1,9 +1,18 @@ # -*- coding: utf-8 -*- import abc import six -from . import tracing +from . import tracing, issues, connection +from . import settings as settings_impl +import threading +from concurrent import futures +import logging +import time +from ydb.public.api.protos import ydb_auth_pb2 +from ydb.public.api.grpc import ydb_auth_v1_pb2_grpc + YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket" +logger = logging.getLogger(__name__) @six.add_metaclass(abc.ABCMeta) @@ -26,6 +35,178 @@ def auth_metadata(self): pass +class OneToManyValue(object): + def __init__(self): + self._value = None + self._condition = threading.Condition() + + def consume(self, timeout=3): + with self._condition: + if self._value is None: + self._condition.wait(timeout=timeout) + return self._value + + def update(self, n_value): + with self._condition: + prev_value = self._value + self._value = n_value + if prev_value is None: + self._condition.notify_all() + + +class AtMostOneExecution(object): + def __init__(self): + self._can_schedule = True + self._lock = threading.Lock() + self._tp = futures.ThreadPoolExecutor(1) + + def wrapped_execution(self, callback): + try: + callback() + except Exception: + pass + + finally: + self.cleanup() + + def submit(self, callback): + with self._lock: + if self._can_schedule: + self._tp.submit(self.wrapped_execution, callback) + self._can_schedule = False + + def cleanup(self): + with self._lock: + self._can_schedule = True + + +@six.add_metaclass(abc.ABCMeta) +class AbstractExpiringTokenCredentials(Credentials): + def __init__(self, tracer=None): + super(AbstractExpiringTokenCredentials, self).__init__(tracer) + self._expires_in = 0 + self._refresh_in = 0 + self._hour = 60 * 60 + self._cached_token = OneToManyValue() + self._tp = AtMostOneExecution() + self.logger = logger.getChild(self.__class__.__name__) + self.last_error = None + self.extra_error_message = "" + + @abc.abstractmethod + def _make_token_request(self): + pass + + def _log_refresh_start(self, current_time): + self.logger.debug("Start refresh token from metadata") + if current_time > self._refresh_in: + self.logger.info( + "Cached token reached refresh_in deadline, current time %s, deadline %s", + current_time, + self._refresh_in, + ) + + if current_time > self._expires_in and self._expires_in > 0: + self.logger.error( + "Cached token reached expires_in deadline, current time %s, deadline %s", + current_time, + self._expires_in, + ) + + def _update_expiration_info(self, auth_metadata): + self._expires_in = time.time() + min( + self._hour, auth_metadata["expires_in"] / 2 + ) + self._refresh_in = time.time() + min( + self._hour / 2, auth_metadata["expires_in"] / 4 + ) + + def _refresh(self): + current_time = time.time() + self._log_refresh_start(current_time) + try: + token_response = self._make_token_request() + self._cached_token.update(token_response["access_token"]) + self._update_expiration_info(token_response) + self.logger.info( + "Token refresh successful. current_time %s, refresh_in %s", + current_time, + self._refresh_in, + ) + + except (KeyboardInterrupt, SystemExit): + return + + except Exception as e: + self.last_error = str(e) + time.sleep(1) + self._tp.submit(self._refresh) + + @property + @tracing.with_trace() + def token(self): + current_time = time.time() + if current_time > self._refresh_in: + tracing.trace(self.tracer, {"refresh": True}) + self._tp.submit(self._refresh) + cached_token = self._cached_token.consume(timeout=3) + tracing.trace(self.tracer, {"consumed": True}) + if cached_token is None: + if self.last_error is None: + raise issues.ConnectionError( + "%s: timeout occurred while waiting for token.\n%s" + % ( + self.__class__.__name__, + self.extra_error_message, + ) + ) + raise issues.ConnectionError( + "%s: %s.\n%s" + % (self.__class__.__name__, self.last_error, self.extra_error_message) + ) + return cached_token + + def auth_metadata(self): + return [(YDB_AUTH_TICKET_HEADER, self.token)] + + +def _wrap_static_credentials_response(rpc_state, response): + issues._process_response(response.operation) + result = ydb_auth_pb2.LoginResult() + response.operation.result.Unpack(result) + return result + + +class StaticCredentials(AbstractExpiringTokenCredentials): + def __init__(self, driver_config, user, password="", tracer=None): + super(StaticCredentials, self).__init__(tracer) + self.driver_config = driver_config + self.user = user + self.password = password + self.request_timeout = 10 + + def _make_token_request(self): + conn = connection.Connection.ready_factory( + self.driver_config.endpoint, self.driver_config + ) + assert conn is not None, ( + "Failed to establish connection in to %s" % self.driver_config.endpoint + ) + try: + result = conn( + ydb_auth_pb2.LoginRequest(user=self.user, password=self.password), + ydb_auth_v1_pb2_grpc.AuthServiceStub, + "Login", + _wrap_static_credentials_response, + settings_impl.BaseRequestSettings() + .with_timeout(self.request_timeout) + .with_need_rpc_auth(False), + ) + finally: + conn.close() + return {"expires_in": 30 * 60, "access_token": result.token} + + class AnonymousCredentials(Credentials): @staticmethod def auth_metadata(): diff --git a/ydb/driver.py b/ydb/driver.py index da300373..9b3fa99c 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -4,41 +4,13 @@ import six import os import grpc +from . import _utilities if six.PY2: Any = None else: from typing import Any # noqa -_grpcs_protocol = "grpcs://" -_grpc_protocol = "grpc://" - - -def is_secure_protocol(endpoint): - return endpoint.startswith("grpcs://") - - -def wrap_endpoint(endpoint): - if endpoint.startswith(_grpcs_protocol): - return endpoint[len(_grpcs_protocol) :] - if endpoint.startswith(_grpc_protocol): - return endpoint[len(_grpc_protocol) :] - return endpoint - - -def parse_connection_string(connection_string): - cs = connection_string - if not cs.startswith(_grpc_protocol) and not cs.startswith(_grpcs_protocol): - # default is grpcs - cs = _grpcs_protocol + cs - - p = six.moves.urllib.parse.urlparse(connection_string) - b = six.moves.urllib.parse.parse_qs(p.query) - database = b.get("database", []) - assert len(database) > 0 - - return p.scheme + "://" + p.netloc, database[0] - class RPCCompression: """Indicates the compression method to be used for an RPC.""" @@ -152,11 +124,11 @@ def __init__( self.database = database self.ca_cert = ca_cert self.channel_options = channel_options - self.secure_channel = is_secure_protocol(endpoint) - self.endpoint = wrap_endpoint(self.endpoint) + self.secure_channel = _utilities.is_secure_protocol(endpoint) + self.endpoint = _utilities.wrap_endpoint(self.endpoint) self.endpoints = [] if endpoints is not None: - self.endpoints = [wrap_endpoint(endp) for endp in endpoints] + self.endpoints = [_utilities.wrap_endpoint(endp) for endp in endpoints] if auth_token is not None: credentials = credentials_impl.AuthTokenCredentials(auth_token) self.credentials = credentials @@ -192,7 +164,7 @@ def default_from_endpoint_and_database( def default_from_connection_string( cls, connection_string, root_certificates=None, credentials=None, **kwargs ): - endpoint, database = parse_connection_string(connection_string) + endpoint, database = _utilities.parse_connection_string(connection_string) return cls( endpoint, database, diff --git a/ydb/iam/auth.py b/ydb/iam/auth.py index 8a21346d..3bd9ddbf 100644 --- a/ydb/iam/auth.py +++ b/ydb/iam/auth.py @@ -6,13 +6,7 @@ import six from datetime import datetime import json -import threading -from concurrent import futures import os -import logging -from ydb import issues - -logger = logging.getLogger(__name__) try: from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc @@ -51,158 +45,20 @@ def get_jwt(account_id, access_key_id, private_key, jwt_expiration_timeout): ) -class OneToManyValue(object): - def __init__(self): - self._value = None - self._condition = threading.Condition() - - def consume(self, timeout=3): - with self._condition: - if self._value is None: - self._condition.wait(timeout=timeout) - return self._value - - def update(self, n_value): - with self._condition: - prev_value = self._value - self._value = n_value - if prev_value is None: - self._condition.notify_all() - - -class AtMostOneExecution(object): - def __init__(self): - self._can_schedule = True - self._lock = threading.Lock() - self._tp = futures.ThreadPoolExecutor(1) - - def wrapped_execution(self, callback): - try: - callback() - except Exception: - pass - - finally: - self.cleanup() - - def submit(self, callback): - with self._lock: - if self._can_schedule: - self._tp.submit(self.wrapped_execution, callback) - self._can_schedule = False - - def cleanup(self): - with self._lock: - self._can_schedule = True - - @six.add_metaclass(abc.ABCMeta) -class IamTokenCredentials(credentials.Credentials): - def __init__(self, tracer=None): - super(IamTokenCredentials, self).__init__(tracer) - self._expires_in = 0 - self._refresh_in = 0 - self._hour = 60 * 60 - self._iam_token = OneToManyValue() - self._tp = AtMostOneExecution() - self.logger = logger.getChild(self.__class__.__name__) - self.last_error = None - self.extra_error_message = "" - - @abc.abstractmethod - def _get_iam_token(self): - pass - - def _log_refresh_start(self, current_time): - self.logger.debug("Start refresh token from metadata") - if current_time > self._refresh_in: - self.logger.info( - "Cached token reached refresh_in deadline, current time %s, deadline %s", - current_time, - self._refresh_in, - ) - - if current_time > self._expires_in and self._expires_in > 0: - self.logger.error( - "Cached token reached expires_in deadline, current time %s, deadline %s", - current_time, - self._expires_in, - ) - - def _update_expiration_info(self, auth_metadata): - self._expires_in = time.time() + min( - self._hour, auth_metadata["expires_in"] / 2 - ) - self._refresh_in = time.time() + min( - self._hour / 2, auth_metadata["expires_in"] / 4 - ) - - def _refresh(self): - current_time = time.time() - self._log_refresh_start(current_time) - try: - auth_metadata = self._get_iam_token() - self._iam_token.update(auth_metadata["access_token"]) - self._update_expiration_info(auth_metadata) - self.logger.info( - "Token refresh successful. current_time %s, refresh_in %s", - current_time, - self._refresh_in, - ) - - except (KeyboardInterrupt, SystemExit): - return - - except Exception as e: - self.last_error = str(e) - time.sleep(1) - self._tp.submit(self._refresh) - - @property - @tracing.with_trace() - def iam_token(self): - current_time = time.time() - if current_time > self._refresh_in: - tracing.trace(self.tracer, {"refresh": True}) - self._tp.submit(self._refresh) - iam_token = self._iam_token.consume(timeout=3) - tracing.trace(self.tracer, {"consumed": True}) - if iam_token is None: - if self.last_error is None: - raise issues.ConnectionError( - "%s: timeout occurred while waiting for token.\n%s" - % self.__class__.__name__, - self.extra_error_message, - ) - raise issues.ConnectionError( - "%s: %s.\n%s" - % (self.__class__.__name__, self.last_error, self.extra_error_message) - ) - return iam_token - - def auth_metadata(self): - return [(credentials.YDB_AUTH_TICKET_HEADER, self.iam_token)] - - -@six.add_metaclass(abc.ABCMeta) -class TokenServiceCredentials(IamTokenCredentials): +class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self, iam_endpoint=None, iam_channel_credentials=None, tracer=None): super(TokenServiceCredentials, self).__init__(tracer) + assert ( + iam_token_service_pb2_grpc is not None + ), "run pip install==ydb[yc] to use service account credentials" + self._get_token_request_timeout = 10 self._iam_endpoint = ( "iam.api.cloud.yandex.net:443" if iam_endpoint is None else iam_endpoint ) self._iam_channel_credentials = ( {} if iam_channel_credentials is None else iam_channel_credentials ) - self._get_token_request_timeout = 10 - if ( - iam_token_service_pb2_grpc is None - or jwt is None - or iam_token_service_pb2 is None - ): - raise RuntimeError( - "Install jwt & yandex python cloud library to use service account credentials provider" - ) def _channel_factory(self): return grpc.secure_channel( @@ -215,7 +71,7 @@ def _get_token_request(self): pass @tracing.with_trace() - def _get_iam_token(self): + def _make_token_request(self): with self._channel_factory() as channel: tracing.trace(self.tracer, {"iam_token.from_service": True}) stub = iam_token_service_pb2_grpc.IamTokenServiceStub(channel) @@ -296,26 +152,24 @@ def _get_token_request(self): ) -class MetadataUrlCredentials(IamTokenCredentials): +class MetadataUrlCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self, metadata_url=None, tracer=None): """ - :param metadata_url: Metadata url :param ydb.Tracer tracer: ydb tracer """ super(MetadataUrlCredentials, self).__init__(tracer) - if requests is None: - raise RuntimeError( - "Install requests library to use metadata credentials provider" - ) + assert ( + requests is not None + ), "Install requests library to use metadata credentials provider" + self.extra_error_message = "Check that metadata service configured properly since we failed to fetch it from metadata_url." self._metadata_url = ( DEFAULT_METADATA_URL if metadata_url is None else metadata_url ) self._tp.submit(self._refresh) - self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud." @tracing.with_trace() - def _get_iam_token(self): + def _make_token_request(self): response = requests.get( self._metadata_url, headers={"Metadata-Flavor": "Google"}, timeout=3 ) diff --git a/ydb/settings.py b/ydb/settings.py index f55b1abb..6739a46f 100644 --- a/ydb/settings.py +++ b/ydb/settings.py @@ -11,6 +11,7 @@ class BaseRequestSettings(object): "tracer", "compression", "headers", + "need_rpc_auth", ) def __init__(self): @@ -23,6 +24,7 @@ def __init__(self): self.cancel_after = None self.operation_timeout = None self.compression = None + self.need_rpc_auth = True self.headers = [] def make_copy(self): @@ -34,6 +36,7 @@ def make_copy(self): .with_cancel_after(self.cancel_after) .with_operation_timeout(self.operation_timeout) .with_compression(self.compression) + .with_need_rpc_auth(self.need_rpc_auth) ) def with_compression(self, compression): @@ -45,6 +48,10 @@ def with_compression(self, compression): self.compression = compression return self + def with_need_rpc_auth(self, need_rpc_auth): + self.need_rpc_auth = need_rpc_auth + return self + def with_header(self, key, value): """ Adds a key-value pair to the request headers.