55import traceback
66import warnings
77from pathlib import Path
8- from typing import Annotated , Optional
8+ from typing import Annotated , ClassVar , Optional
99
1010from openbb_core .app .constants import USER_SETTINGS_PATH
1111from openbb_core .app .extension_loader import ExtensionLoader
@@ -39,10 +39,29 @@ class CredentialsLoader:
3939 """Here we create the Credentials model."""
4040
4141 credentials : dict [str , list [str ]] = {}
42+ env = Env ()
43+
44+ @staticmethod
45+ def _normalize_credential_map (raw : dict | None ) -> dict [str , object ]:
46+ """Lower-case keys and drop empty overrides so env values can win."""
47+ if not raw :
48+ return {}
49+ normalized : dict [str , object ] = {}
50+ for key , value in raw .items ():
51+ if not isinstance (key , str ):
52+ normalized [key ] = value
53+ continue
54+ normalized_key = key .strip ().lower ()
55+ if normalized_key in normalized and value in (None , "" ):
56+ continue
57+ normalized [normalized_key ] = value
58+ return normalized
4259
4360 def format_credentials (self , additional : dict ) -> dict [str , tuple [object , None ]]:
4461 """Prepare credentials map to be used in the Credentials model."""
4562 formatted : dict [str , tuple [object , None ]] = {}
63+ additional_data = dict (additional )
64+
4665 for c_origin , c_list in self .credentials .items ():
4766 for c_name in c_list :
4867 if c_name in formatted :
@@ -51,13 +70,18 @@ def format_credentials(self, additional: dict) -> dict[str, tuple[object, None]]
5170 category = OpenBBWarning ,
5271 )
5372 continue
73+ default_value = additional_data .pop (c_name , None )
5474 formatted [c_name ] = (
5575 Optional [OBBSecretStr ], # noqa
56- Field (default = None , description = c_origin , alias = c_name .upper ()),
76+ Field (
77+ default = default_value ,
78+ description = c_origin ,
79+ alias = c_name .upper (),
80+ ),
5781 )
5882
59- if additional :
60- for key , value in additional .items ():
83+ if additional_data :
84+ for key , value in additional_data .items ():
6185 if key in formatted :
6286 continue
6387 formatted [key ] = (
@@ -94,8 +118,6 @@ def from_providers(self) -> None:
94118
95119 def load (self ) -> BaseModel :
96120 """Load credentials from providers."""
97- # We load providers first to give them priority choosing credential names
98- _ = Env ()
99121 self .from_providers ()
100122 self .from_obbject ()
101123 path = Path (USER_SETTINGS_PATH )
@@ -107,34 +129,42 @@ def load(self) -> BaseModel:
107129 if "credentials" in data :
108130 additional = data ["credentials" ]
109131
110- # Collect all keys from providers to match with environment variables
132+ additional = self ._normalize_credential_map (additional )
133+
111134 all_keys = [
112135 key
113136 for keys in ProviderInterface ().credentials .values ()
114137 if keys
115138 for key in keys
116139 ]
117140
118- for key in all_keys :
119- if key .upper () in os .environ :
120- value = os .environ [key .upper ()]
121- if value :
122- additional [key ] = SecretStr (value )
141+ env_credentials : dict [str , SecretStr ] = {}
142+ for env_key , value in os .environ .items ():
143+ if not value :
144+ continue
145+ lower_key = env_key .lower ()
146+ if lower_key in all_keys or env_key .endswith ("API_KEY" ):
147+ canonical_key = lower_key if lower_key in all_keys else lower_key
148+ env_credentials [canonical_key ] = SecretStr (value )
123149
124- # Collect all environment variables ending with API_KEY
125- environ_keys = [ d for d in os . environ if d . endswith ( "API_KEY" )]
150+ if env_credentials :
151+ additional . update ( env_credentials )
126152
127- for key in environ_keys :
128- value = os .environ [key ]
129- if value :
130- additional [key .lower ()] = SecretStr (value )
153+ additional = self ._normalize_credential_map (additional )
154+
155+ env_overrides = {
156+ key : additional [key ]
157+ for key in env_credentials
158+ if key in additional and additional [key ] not in (None , "" )
159+ }
131160
132161 model = create_model (
133162 "Credentials" ,
134163 __config__ = ConfigDict (validate_assignment = True , populate_by_name = True ),
135164 ** self .format_credentials (additional ), # type: ignore
136165 )
137- model .origins = self .credentials
166+ model ._env_defaults = env_overrides # type: ignore # pylint: disable=W0212
167+
138168 return model
139169
140170
@@ -145,6 +175,29 @@ class Credentials(_Credentials): # type: ignore
145175 """Credentials model used to store provider credentials."""
146176
147177 model_config = ConfigDict (extra = "allow" )
178+ _env_defaults : ClassVar [dict [str , object ]] = getattr (
179+ _Credentials , "_env_defaults" , {}
180+ )
181+
182+ @staticmethod
183+ def _is_unset (value : object ) -> bool :
184+ if value is None :
185+ return True
186+ if isinstance (value , SecretStr ):
187+ return not value .get_secret_value ()
188+ if isinstance (value , str ):
189+ return value == ""
190+ return False
191+
192+ def model_post_init (self , __context ) -> None :
193+ """Set unset credentials from environment variables."""
194+ super ().model_post_init (__context )
195+ for key , secret in self ._env_defaults .items ():
196+ if key not in self .model_fields :
197+ continue
198+ current = getattr (self , key , None )
199+ if self ._is_unset (current ):
200+ setattr (self , key , secret )
148201
149202 def __repr__ (self ) -> str :
150203 """Define the string representation of the credentials."""
0 commit comments