Skip to content

Commit 257714d

Browse files
feat(python): Add Python-side caching for credentials and provider auto-initialization (#23736)
1 parent f576f7e commit 257714d

File tree

4 files changed

+402
-25
lines changed

4 files changed

+402
-25
lines changed

py-polars/polars/io/cloud/_utils.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,53 @@
11
from __future__ import annotations
22

33
from pathlib import Path
4-
from typing import Any
4+
from typing import Any, Generic, TypeVar
55

66
from polars._utils.various import is_path_or_str_sequence
77
from polars.io.partition import PartitionMaxSize
88

9+
T = TypeVar("T")
10+
11+
12+
class NoPickleOption(Generic[T]):
    """
    Wrapper that does not pickle the wrapped value.

    This wrapper will unpickle to contain a None. Used for cached values.

    The wrapped value is never part of the pickled state, so a round-trip
    through `pickle` always yields an empty wrapper regardless of what the
    original instance held and regardless of the pickle protocol used.
    """

    # Class-level fallback for the wrapped value. Pickle protocols 0 and 1 do
    # not call `__setstate__()` when `__getstate__()` returns a falsy value
    # (such as the empty tuple used below), which would otherwise leave the
    # unpickled instance without an `_opt_value` attribute at all.
    _opt_value: T | None = None

    def __init__(self, opt_value: T | None = None) -> None:
        self._opt_value = opt_value

    def get(self) -> T | None:
        """Return the wrapped value, or None if the wrapper is empty."""
        return self._opt_value

    def set(self, value: T | None) -> None:
        """Replace the wrapped value (pass None to clear it)."""
        self._opt_value = value

    def __getstate__(self) -> tuple[()]:
        # Needs to return not-None for `__setstate__()` to be called
        # (pickle protocol 2 and above). The wrapped value is deliberately
        # left out of the state.
        return ()

    def __setstate__(self, _state: tuple[()]) -> None:
        # Reset to an empty wrapper; the pickled state carries no value.
        NoPickleOption.__init__(self)
34+
35+
36+
class ZeroHashWrap(Generic[T]):
    """
    Carry a value while presenting a constant identity for hashing/equality.

    Every instance hashes to 0 and compares equal to any other object, so the
    wrapped value never influences a hash-based lookup that the wrapper takes
    part in.
    """

    def __init__(self, value: T) -> None:
        self._value = value

    def __hash__(self) -> int:
        # Constant on purpose: the payload must not contribute to the hash.
        return 0

    def __eq__(self, _other: object) -> bool:
        # Unconditionally equal, whatever is on the other side.
        return True

    def get(self) -> T:
        """Return the wrapped value."""
        return self._value
50+
951

1052
def _first_scan_path(
1153
source: Any,

py-polars/polars/io/cloud/credential_provider/_builder.py

Lines changed: 82 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,30 @@
11
from __future__ import annotations
22

33
import abc
4-
from typing import TYPE_CHECKING, Any, Literal
4+
import os
5+
from functools import lru_cache
6+
from typing import TYPE_CHECKING, Any, Callable, Literal, Union
57

68
import polars._utils.logging
79
from polars._utils.logging import eprint, verbose
810
from polars._utils.unstable import issue_unstable_warning
11+
from polars.io.cloud._utils import NoPickleOption, ZeroHashWrap
912
from polars.io.cloud.credential_provider._providers import (
1013
CredentialProvider,
1114
CredentialProviderAWS,
1215
CredentialProviderAzure,
16+
CredentialProviderFunction,
1317
CredentialProviderGCP,
18+
UserProvidedGCPToken,
1419
)
1520

1621
if TYPE_CHECKING:
17-
from polars.io.cloud.credential_provider._providers import (
18-
CredentialProviderFunction,
19-
CredentialProviderFunctionReturn,
20-
)
22+
import sys
23+
24+
if sys.version_info >= (3, 10):
25+
from typing import TypeAlias
26+
else:
27+
from typing_extensions import TypeAlias
2128

2229
# https://docs.rs/object_store/latest/object_store/enum.ClientConfigKey.html
2330
OBJECT_STORE_CLIENT_OPTIONS: frozenset[str] = frozenset(
@@ -42,6 +49,10 @@
4249
]
4350
)
4451

52+
CredentialProviderBuilderReturn: TypeAlias = Union[
53+
CredentialProvider, CredentialProviderFunction, None
54+
]
55+
4556

4657
class CredentialProviderBuilder:
4758
"""
@@ -70,7 +81,7 @@ def __init__(
7081
# Note: The rust-side expects this exact function name.
7182
def build_credential_provider(
7283
self,
73-
) -> CredentialProvider | CredentialProviderFunction | None:
84+
) -> CredentialProviderBuilderReturn:
7485
"""Instantiate a credential provider from configuration."""
7586
verbose = polars._utils.logging.verbose()
7687

@@ -154,38 +165,91 @@ def provider_repr(self) -> str:
154165
return repr(self.credential_provider)
155166

156167

168+
AUTO_INIT_LRU_CACHE: (
169+
Callable[
170+
[bytes, ZeroHashWrap[Callable[[], CredentialProviderBuilderReturn]]],
171+
CredentialProviderBuilderReturn,
172+
]
173+
| None
174+
) = None
175+
176+
177+
def _auto_init_with_cache(
    get_cache_key_func: Callable[[], bytes],
    build_provider_func: ZeroHashWrap[Callable[[], CredentialProviderBuilderReturn]],
) -> CredentialProviderBuilderReturn:
    """
    Build a credential provider, reusing a previous result for the same key.

    Parameters
    ----------
    get_cache_key_func
        Called to obtain the `bytes` key that identifies the configuration.
        Only invoked when caching is enabled.
    build_provider_func
        Wrapped zero-argument callable that constructs the provider. It is
        passed to the LRU cache alongside the key, wrapped in `ZeroHashWrap`.

    Notes
    -----
    The cache capacity is read from the
    `POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE` environment variable on
    every call (default 8). A value `<= 0` drops the module-level cache and
    builds the provider directly. The capacity is only applied at the moment
    the module-level cache is (re)created.
    """
    global AUTO_INIT_LRU_CACHE

    maxsize = int(os.getenv("POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE", 8))

    # Caching disabled: forget any existing cache and build directly.
    if maxsize <= 0:
        AUTO_INIT_LRU_CACHE = None
        return build_provider_func.get()()

    # Lazily create the module-level cache on first use.
    if AUTO_INIT_LRU_CACHE is None:
        if verbose():
            eprint(f"Create credential provider AutoInit LRU cache ({maxsize = })")

        def _build_on_miss(
            _cache_key: bytes,
            wrapped_builder: ZeroHashWrap[
                Callable[[], CredentialProviderBuilderReturn]
            ],
        ) -> CredentialProviderBuilderReturn:
            # Only runs on a cache miss; the key itself is not needed here.
            return wrapped_builder.get()()

        AUTO_INIT_LRU_CACHE = lru_cache(maxsize=maxsize)(_build_on_miss)

    return AUTO_INIT_LRU_CACHE(get_cache_key_func(), build_provider_func)
209+
210+
157211
# Represents an automatic initialization configuration. This is created for
158212
# credential_provider="auto".
159213
class AutoInit(CredentialProviderBuilderImpl):
160214
def __init__(self, cls: Any, **kw: Any) -> None:
161215
self.cls = cls
162216
self.kw = kw
217+
self._cache_key: NoPickleOption[bytes] = NoPickleOption()
163218

164-
def __call__(self) -> Any:
219+
def __call__(self) -> CredentialProviderFunction | None:
165220
# This is used for credential_provider="auto", which allows for
166221
# ImportErrors.
167222
try:
168-
return self.cls(**self.kw)
223+
return _auto_init_with_cache(
224+
self.get_or_init_cache_key,
225+
ZeroHashWrap(lambda: self.cls(**self.kw)),
226+
)
169227
except ImportError as e:
170228
if verbose():
171229
eprint(f"failed to auto-initialize {self.provider_repr}: {e!r}")
172230

173231
return None
174232

175-
@property
176-
def provider_repr(self) -> str:
177-
return self.cls.__name__
233+
def get_or_init_cache_key(self) -> bytes:
234+
cache_key = self._cache_key.get()
178235

236+
if cache_key is None:
237+
import hashlib
238+
import pickle
179239

180-
class UserProvidedGCPToken(CredentialProvider):
181-
"""User-provided GCP token in storage_options."""
240+
hash = hashlib.sha256(pickle.dumps(self))
241+
self._cache_key.set(hash.digest())
242+
cache_key = self._cache_key.get()
243+
assert isinstance(cache_key, bytes)
182244

183-
def __init__(self, token: str) -> None:
184-
self.token = token
245+
if verbose():
246+
eprint(f"{self!r}: AutoInit cache key: {hash.hexdigest()}")
247+
248+
return cache_key
185249

186-
def __call__(self) -> CredentialProviderFunctionReturn:
187-
"""Fetches the credentials."""
188-
return {"bearer_token": self.token}, None
250+
@property
251+
def provider_repr(self) -> str:
252+
return self.cls.__name__
189253

190254

191255
def _init_credential_provider_builder(

py-polars/polars/io/cloud/credential_provider/_providers.py

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,18 @@
88
import sys
99
import zoneinfo
1010
from datetime import datetime
11-
from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union
11+
from typing import (
12+
TYPE_CHECKING,
13+
Any,
14+
Callable,
15+
Optional,
16+
TypedDict,
17+
Union,
18+
)
1219

1320
import polars._utils.logging
1421
from polars._utils.logging import eprint, verbose
22+
from polars.io.cloud._utils import NoPickleOption
1523

1624
if TYPE_CHECKING:
1725
if sys.version_info >= (3, 10):
@@ -56,9 +64,53 @@ class CredentialProvider(abc.ABC):
5664
at any point without it being considered a breaking change.
5765
"""
5866

59-
@abc.abstractmethod
67+
def __init__(self) -> None:
    """
    Set up the per-instance credential cache.

    Subclasses are expected to call `super().__init__()`; `__call__` raises
    an `AttributeError` if `_cached_credentials` was never created.
    """
    # Starts empty; it is filled by `__call__`.
    empty_cache: NoPickleOption[CredentialProviderFunctionReturn] = NoPickleOption()
    self._cached_credentials = empty_cache
    # Ensures the "using cached credentials" message is logged once per fetch.
    self._has_logged_use_cache = False

    if not verbose():
        return

    instance_tag = f"[{type(self).__name__} @ {hex(id(self))}]"
    eprint(f"{instance_tag}: CredentialProvider.__init__()")
77+
6078
def __call__(self) -> CredentialProviderFunctionReturn:
    """
    Fetches the credentials.

    Returns the cached credentials when they are present and not expired,
    otherwise calls `retrieve_credentials_impl()` and caches its result.
    Setting `POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING=1` clears the cache
    and fetches fresh credentials on every call.

    Raises
    ------
    AttributeError
        If `_cached_credentials` is missing, e.g. because a subclass did not
        call `super().__init__()`.
    """
    # Validate before any access to `_cached_credentials`, so that the
    # caching-disabled path below also reports the descriptive error instead
    # of a bare attribute lookup failure.
    if not isinstance(getattr(self, "_cached_credentials", None), NoPickleOption):
        msg = (
            f"[{type(self).__name__} @ {hex(id(self))}]: `_cached_credentials` attribute "
            "not found. This can happen if a subclass forgets to call "
            f"super().__init__() ({type(self) = })"
        )
        raise AttributeError(msg)  # noqa: TRY004

    if os.getenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING") == "1":
        self._cached_credentials.set(None)
        return self.retrieve_credentials_impl()

    cached = self._cached_credentials.get()

    # Refresh when nothing is cached, or when the cached entry carries an
    # expiry (second tuple element) that is not later than the current time.
    if cached is None or (
        (expiry := cached[1]) is not None
        and expiry <= int(datetime.now().timestamp())
    ):
        self._cached_credentials.set(self.retrieve_credentials_impl())
        self._has_logged_use_cache = False
        cached = self._cached_credentials.get()
        assert cached is not None

    elif verbose() and not self._has_logged_use_cache:
        # Log the cache hit only once per fetched set of credentials.
        expiry = cached[1]
        eprint(
            f"[{type(self).__name__} @ {hex(id(self))}]: Using cached credentials ({expiry = })"
        )
        self._has_logged_use_cache = True

    return cached
111+
112+
@abc.abstractmethod
def retrieve_credentials_impl(self) -> CredentialProviderFunctionReturn:
    """
    Fetch fresh credentials from the underlying source.

    Subclasses must implement this. `__call__` invokes it when caching is
    disabled, when nothing is cached yet, or when the cached entry has
    expired. `__call__` reads the second element of the returned value as
    the expiry; `None` there means the entry is never refreshed due to
    expiry.
    """
62114

63115

64116
class CredentialProviderAWS(CredentialProvider):
@@ -95,14 +147,16 @@ def __init__(
95147
msg = "`CredentialProviderAWS` functionality is considered unstable"
96148
issue_unstable_warning(msg)
97149

150+
super().__init__()
151+
98152
self._ensure_module_availability()
99153
self.profile_name = profile_name
100154
self.region_name = region_name
101155
self.assume_role = assume_role
102156
self._auto_init_unhandled_key = _auto_init_unhandled_key
103157
self._storage_options_has_endpoint_url = _storage_options_has_endpoint_url
104158

105-
def __call__(self) -> CredentialProviderFunctionReturn:
159+
def retrieve_credentials_impl(self) -> CredentialProviderFunctionReturn:
106160
"""Fetch the credentials for the configured profile name."""
107161
assert not self._auto_init_unhandled_key
108162

@@ -253,6 +307,8 @@ def __init__(
253307
msg = "`CredentialProviderAzure` functionality is considered unstable"
254308
issue_unstable_warning(msg)
255309

310+
super().__init__()
311+
256312
self.account_name = _storage_account
257313
self.scopes = (
258314
scopes if scopes is not None else ["https://storage.azure.com/.default"]
@@ -283,7 +339,7 @@ def __init__(
283339
f"{self.scopes = } "
284340
)
285341

286-
def __call__(self) -> CredentialProviderFunctionReturn:
342+
def retrieve_credentials_impl(self) -> CredentialProviderFunctionReturn:
287343
"""Fetch the credentials."""
288344
if (
289345
v := self._try_get_azure_storage_account_credential_if_permitted()
@@ -438,6 +494,8 @@ def __init__(
438494
msg = "`CredentialProviderGCP` functionality is considered unstable"
439495
issue_unstable_warning(msg)
440496

497+
super().__init__()
498+
441499
self._ensure_module_availability()
442500

443501
import google.auth
@@ -462,7 +520,7 @@ def __init__(
462520
)
463521
self.creds = creds
464522

465-
def __call__(self) -> CredentialProviderFunctionReturn:
523+
def retrieve_credentials_impl(self) -> CredentialProviderFunctionReturn:
466524
"""Fetch the credentials."""
467525
import google.auth.transport.requests
468526

@@ -487,6 +545,20 @@ def _ensure_module_availability(cls) -> None:
487545
raise ImportError(msg)
488546

489547

548+
class UserProvidedGCPToken(CredentialProvider):
    """
    User-provided GCP token in storage_options.

    The token is returned as-is with no expiry. `__call__` is overridden to
    return it directly, so the credential cache in the base class `__call__`
    is not used by this provider.
    """

    def __init__(self, token: str) -> None:
        # Does not call `super().__init__()`: `_cached_credentials` is never
        # created here, and the overridden `__call__` below never reads it.
        self.token = token

    def retrieve_credentials_impl(self) -> CredentialProviderFunctionReturn:
        """Fetches the credentials."""
        credentials = {"bearer_token": self.token}
        # The second element is the expiry; None is returned for it.
        return credentials, None

    def __call__(self) -> CredentialProviderFunctionReturn:
        """Return the user-provided token without consulting any cache."""
        return self.retrieve_credentials_impl()
560+
561+
490562
def _get_credentials_from_provider_expiry_aware(
491563
credential_provider: CredentialProviderFunction,
492564
) -> dict[str, str] | None:

0 commit comments

Comments
 (0)