Skip to content

Commit 2479c96

Browse files
authored
Merge branch 'develop' into feature/before-command-output
2 parents 2515a1d + 3257b9f commit 2479c96

File tree

3 files changed

+103
-22
lines changed

3 files changed

+103
-22
lines changed

openbb_platform/core/openbb_core/app/model/credentials.py

Lines changed: 72 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import traceback
66
import warnings
77
from pathlib import Path
8-
from typing import Annotated, Optional
8+
from typing import Annotated, ClassVar, Optional
99

1010
from openbb_core.app.constants import USER_SETTINGS_PATH
1111
from openbb_core.app.extension_loader import ExtensionLoader
@@ -39,10 +39,29 @@ class CredentialsLoader:
3939
"""Here we create the Credentials model."""
4040

4141
credentials: dict[str, list[str]] = {}
42+
env = Env()
43+
44+
@staticmethod
45+
def _normalize_credential_map(raw: dict | None) -> dict[str, object]:
46+
"""Lower-case keys and drop empty overrides so env values can win."""
47+
if not raw:
48+
return {}
49+
normalized: dict[str, object] = {}
50+
for key, value in raw.items():
51+
if not isinstance(key, str):
52+
normalized[key] = value
53+
continue
54+
normalized_key = key.strip().lower()
55+
if normalized_key in normalized and value in (None, ""):
56+
continue
57+
normalized[normalized_key] = value
58+
return normalized
4259

4360
def format_credentials(self, additional: dict) -> dict[str, tuple[object, None]]:
4461
"""Prepare credentials map to be used in the Credentials model."""
4562
formatted: dict[str, tuple[object, None]] = {}
63+
additional_data = dict(additional)
64+
4665
for c_origin, c_list in self.credentials.items():
4766
for c_name in c_list:
4867
if c_name in formatted:
@@ -51,13 +70,18 @@ def format_credentials(self, additional: dict) -> dict[str, tuple[object, None]]
5170
category=OpenBBWarning,
5271
)
5372
continue
73+
default_value = additional_data.pop(c_name, None)
5474
formatted[c_name] = (
5575
Optional[OBBSecretStr], # noqa
56-
Field(default=None, description=c_origin, alias=c_name.upper()),
76+
Field(
77+
default=default_value,
78+
description=c_origin,
79+
alias=c_name.upper(),
80+
),
5781
)
5882

59-
if additional:
60-
for key, value in additional.items():
83+
if additional_data:
84+
for key, value in additional_data.items():
6185
if key in formatted:
6286
continue
6387
formatted[key] = (
@@ -94,8 +118,6 @@ def from_providers(self) -> None:
94118

95119
def load(self) -> BaseModel:
96120
"""Load credentials from providers."""
97-
# We load providers first to give them priority choosing credential names
98-
_ = Env()
99121
self.from_providers()
100122
self.from_obbject()
101123
path = Path(USER_SETTINGS_PATH)
@@ -107,34 +129,42 @@ def load(self) -> BaseModel:
107129
if "credentials" in data:
108130
additional = data["credentials"]
109131

110-
# Collect all keys from providers to match with environment variables
132+
additional = self._normalize_credential_map(additional)
133+
111134
all_keys = [
112135
key
113136
for keys in ProviderInterface().credentials.values()
114137
if keys
115138
for key in keys
116139
]
117140

118-
for key in all_keys:
119-
if key.upper() in os.environ:
120-
value = os.environ[key.upper()]
121-
if value:
122-
additional[key] = SecretStr(value)
141+
env_credentials: dict[str, SecretStr] = {}
142+
for env_key, value in os.environ.items():
143+
if not value:
144+
continue
145+
lower_key = env_key.lower()
146+
if lower_key in all_keys or env_key.endswith("API_KEY"):
147+
canonical_key = lower_key if lower_key in all_keys else lower_key
148+
env_credentials[canonical_key] = SecretStr(value)
123149

124-
# Collect all environment variables ending with API_KEY
125-
environ_keys = [d for d in os.environ if d.endswith("API_KEY")]
150+
if env_credentials:
151+
additional.update(env_credentials)
126152

127-
for key in environ_keys:
128-
value = os.environ[key]
129-
if value:
130-
additional[key.lower()] = SecretStr(value)
153+
additional = self._normalize_credential_map(additional)
154+
155+
env_overrides = {
156+
key: additional[key]
157+
for key in env_credentials
158+
if key in additional and additional[key] not in (None, "")
159+
}
131160

132161
model = create_model(
133162
"Credentials",
134163
__config__=ConfigDict(validate_assignment=True, populate_by_name=True),
135164
**self.format_credentials(additional), # type: ignore
136165
)
137-
model.origins = self.credentials
166+
model._env_defaults = env_overrides # type: ignore # pylint: disable=W0212
167+
138168
return model
139169

140170

@@ -145,6 +175,29 @@ class Credentials(_Credentials): # type: ignore
145175
"""Credentials model used to store provider credentials."""
146176

147177
model_config = ConfigDict(extra="allow")
178+
_env_defaults: ClassVar[dict[str, object]] = getattr(
179+
_Credentials, "_env_defaults", {}
180+
)
181+
182+
@staticmethod
183+
def _is_unset(value: object) -> bool:
184+
if value is None:
185+
return True
186+
if isinstance(value, SecretStr):
187+
return not value.get_secret_value()
188+
if isinstance(value, str):
189+
return value == ""
190+
return False
191+
192+
def model_post_init(self, __context) -> None:
193+
"""Set unset credentials from environment variables."""
194+
super().model_post_init(__context)
195+
for key, secret in self._env_defaults.items():
196+
if key not in self.model_fields:
197+
continue
198+
current = getattr(self, key, None)
199+
if self._is_unset(current):
200+
setattr(self, key, secret)
148201

149202
def __repr__(self) -> str:
150203
"""Define the string representation of the credentials."""

openbb_platform/core/tests/app/model/test_credentials.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Test the Credentials model."""
22

33
import importlib
4-
from unittest.mock import patch
4+
import json
5+
from unittest.mock import mock_open, patch
56

67

78
# pylint: disable=import-outside-toplevel
@@ -33,3 +34,30 @@ def test_credentials():
3334
assert creds.benzinga_api_key.get_secret_value() == "mock_benzinga_api_key"
3435
assert creds.polygon_api_key.get_secret_value() == "mock_polygon_api_key"
3536
assert creds.mock_env_api_key.get_secret_value() == "mock_env_key_value"
37+
38+
39+
def test_credentials_env_overrides_null():
40+
"""Environment variables replace stored null credentials."""
41+
fake_user_settings = json.dumps({"credentials": {"econdb_api_key": None}})
42+
with (
43+
patch(
44+
"openbb_core.app.model.credentials.ProviderInterface"
45+
) as mock_provider_interface,
46+
patch("openbb_core.app.model.credentials.ExtensionLoader") as mock_loader,
47+
patch("openbb_core.app.model.credentials.Path.exists", return_value=True),
48+
patch("builtins.open", mock_open(read_data=fake_user_settings)),
49+
patch.dict("os.environ", {"ECONDB_API_KEY": "env_econdb_key"}),
50+
):
51+
mock_provider_interface.return_value.credentials = {
52+
"econdb": ["econdb_api_key"]
53+
}
54+
mock_loader.return_value.obbject_objects = {}
55+
56+
import openbb_core.app.model.credentials as credentials_module
57+
58+
importlib.reload(credentials_module)
59+
Credentials = credentials_module.Credentials
60+
61+
creds = Credentials()
62+
63+
assert creds.econdb_api_key.get_secret_value() == "env_econdb_key"

openbb_platform/dev_install.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def install_platform_local(_extras: bool = False):
142142
extras_args = ["-E", "all"] if _extras else []
143143

144144
subprocess.run(
145-
CMD + ["lock"],
145+
CMD + ["lock", "--regenerate"],
146146
cwd=PLATFORM_PATH,
147147
check=True,
148148
)
@@ -187,7 +187,7 @@ def install_platform_cli():
187187
CMD = [sys.executable, "-m", "poetry"]
188188

189189
subprocess.run(
190-
CMD + ["lock"],
190+
CMD + ["lock", "--regenerate"],
191191
cwd=CLI_PATH,
192192
check=True, # noqa: S603
193193
)

0 commit comments

Comments
 (0)