Skip to content

Commit d4320b1

Browse files
committed
Merge branch 'develop' of https://github.com/OpenBB-finance/OpenBB into feature/openbb-cookiecutter
2 parents aa35d55 + ae86e04 commit d4320b1

File tree

13 files changed

+527
-37
lines changed

13 files changed

+527
-37
lines changed

openbb_platform/core/openbb_core/api/router/commands.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
from typing import Annotated, Any, TypeVar
88

99
from fastapi import APIRouter, Depends, Header
10+
from fastapi.encoders import jsonable_encoder
11+
from fastapi.responses import JSONResponse
1012
from fastapi.routing import APIRoute
1113
from openbb_core.app.command_runner import CommandRunner
14+
from openbb_core.app.model.abstract.error import OpenBBError
1215
from openbb_core.app.model.command_context import CommandContext
1316
from openbb_core.app.model.obbject import OBBject
1417
from openbb_core.app.model.user_settings import UserSettings
@@ -140,7 +143,7 @@ def is_model(type_):
140143

141144
def exclude_fields_from_api(key: str, value: Any):
142145
type_ = type(value)
143-
field = c_out.model_fields.get(key, None)
146+
field = getattr(type(c_out), "model_fields", {}).get(key, None)
144147
json_schema_extra = field.json_schema_extra if field else None
145148

146149
# case where 1st layer field needs to be excluded
@@ -198,7 +201,9 @@ def build_api_wrapper(
198201
route.response_model = None
199202

200203
@wraps(wrapped=func)
201-
async def wrapper(*args: tuple[Any], **kwargs: dict[str, Any]) -> OBBject:
204+
async def wrapper(
205+
*args: tuple[Any], **kwargs: dict[str, Any]
206+
) -> OBBject | JSONResponse:
202207
user_settings: UserSettings = UserSettings.model_validate(
203208
kwargs.pop(
204209
"__authenticated_user_settings",
@@ -213,7 +218,7 @@ async def wrapper(*args: tuple[Any], **kwargs: dict[str, Any]) -> OBBject:
213218
)
214219

215220
if defaults:
216-
_provider = defaults.pop("provider", None)
221+
_ = defaults.pop("provider", None)
217222
standard_params = getattr(
218223
kwargs.pop("standard_params", None), "__dict__", {}
219224
)
@@ -243,6 +248,23 @@ async def wrapper(*args: tuple[Any], **kwargs: dict[str, Any]) -> OBBject:
243248
execute = partial(command_runner.run, path, user_settings)
244249
output = await execute(*args, **kwargs)
245250

251+
# This is where we check for `on_command_output` extensions
252+
mutated_output = getattr(output, "_extension_modified", False)
253+
results_only = getattr(output, "_results_only", False)
254+
try:
255+
if results_only is True:
256+
content = output.model_dump(exclude_unset=True).get("results", [])
257+
return JSONResponse(content=jsonable_encoder(content), status_code=200)
258+
259+
if (mutated_output and isinstance(output, OBBject)) or (
260+
isinstance(output, OBBject) and no_validate
261+
):
262+
return JSONResponse(content=jsonable_encoder(output), status_code=200)
263+
except Exception as exc: # pylint: disable=W0703
264+
raise OpenBBError(
265+
f"Error serializing output for an extension-modified endpoint {path}: {exc}",
266+
) from exc
267+
246268
if isinstance(output, OBBject) and not no_validate:
247269
return validate_output(output)
248270

openbb_platform/core/openbb_core/app/command_runner.py

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from copy import deepcopy
77
from dataclasses import asdict, is_dataclass
88
from datetime import datetime
9-
from inspect import Parameter, signature
9+
from inspect import Parameter, iscoroutinefunction, signature
1010
from sys import exc_info
1111
from time import perf_counter_ns
1212
from typing import TYPE_CHECKING, Any, Optional
1313
from warnings import catch_warnings, showwarning, warn
1414

15+
from openbb_core.app.extension_loader import ExtensionLoader
1516
from openbb_core.app.model.abstract.error import OpenBBError
1617
from openbb_core.app.model.abstract.warning import OpenBBWarning, cast_warning
1718
from openbb_core.app.model.metadata import Metadata
@@ -272,6 +273,7 @@ def _chart(
272273

273274
if chart_params:
274275
kwargs.update(chart_params)
276+
275277
obbject.charting.show(render=False, **kwargs) # type: ignore[attr-defined]
276278
except Exception as e: # pylint: disable=broad-exception-caught
277279
if Env().DEBUG_MODE:
@@ -311,6 +313,9 @@ async def _execute_func( # pylint: disable=too-many-positional-arguments
311313
# in the charting extension then we add it there. This way we can remove
312314
# the chart parameter from the commands.py and package_builder, it will be
313315
# added to the function signature in the router decorator
316+
# If the ProviderInterface is not in use, we need to pass a copy of the
317+
# kwargs dictionary before it is validated, otherwise we lose those items.
318+
kwargs_copy = deepcopy(kwargs)
314319
chart = kwargs.pop("chart", False)
315320

316321
kwargs = ParametersBuilder.build(
@@ -348,7 +353,21 @@ async def _execute_func( # pylint: disable=too-many-positional-arguments
348353
extra_params
349354
)
350355
if chart and obbject.results:
351-
cls._chart(obbject, **kwargs)
356+
if "extra_params" not in kwargs_copy:
357+
kwargs_copy["extra_params"] = {}
358+
# Restore any kwargs passed that were removed by the ParametersBuilder
359+
for k in kwargs_copy.copy():
360+
if k == "chart":
361+
kwargs_copy.pop("chart", None)
362+
continue
363+
if (
364+
not extra_params or k not in extra_params
365+
) and k != "extra_params":
366+
kwargs_copy["extra_params"][k] = kwargs_copy.pop(
367+
k, None
368+
)
369+
370+
cls._chart(obbject, **kwargs_copy)
352371

353372
raised_warnings = warning_list if warning_list else []
354373
finally:
@@ -429,8 +448,79 @@ async def run(
429448
raise OpenBBError(e) from e
430449
warn(str(e), OpenBBWarning)
431450

451+
try:
452+
cls._trigger_command_output_callbacks(route, obbject)
453+
454+
except Exception as e:
455+
if Env().DEBUG_MODE:
456+
raise OpenBBError(e) from e
457+
warn(str(e), OpenBBWarning)
458+
432459
return obbject
433460

461+
@classmethod
462+
def _trigger_command_output_callbacks(cls, route: str, obbject: OBBject) -> None:
463+
"""Trigger command output callbacks for extensions."""
464+
loader = ExtensionLoader()
465+
callbacks = loader.on_command_output_callbacks
466+
results_only = False
467+
468+
# For each extension registered for all routes or the specific route,
469+
# we call its accessor on the OBBject.
470+
# We check if the accessor is immutable or not to decide whether to pass
471+
# a copy of the OBBject or the original one.
472+
# We set the _extension_modified attribute to True if any extension
473+
# mutates the OBBject so we can pass this information to the interface.
474+
# We also set the _results_only attribute to True if any extension
475+
# indicates that only results should be returned.
476+
if "*" in callbacks:
477+
for ext in callbacks["*"]:
478+
if ext.results_only is True:
479+
results_only = True
480+
if ext.immutable is True:
481+
if hasattr(obbject, ext.name):
482+
obbject_copy = deepcopy(obbject)
483+
accessor = getattr(obbject_copy, ext.name)
484+
if iscoroutinefunction(accessor):
485+
run_async(accessor)
486+
elif callable(accessor):
487+
accessor()
488+
elif ext.immutable is False:
489+
if ext.results_only is True:
490+
results_only = True
491+
if hasattr(obbject, ext.name):
492+
accessor = getattr(obbject, ext.name)
493+
if iscoroutinefunction(accessor):
494+
run_async(accessor)
495+
elif callable(accessor):
496+
accessor()
497+
setattr(obbject, "_extension_modified", True)
498+
499+
if route in callbacks:
500+
for ext in callbacks[route]:
501+
if ext.results_only is True:
502+
results_only = True
503+
504+
if ext.immutable is True:
505+
if hasattr(obbject, ext.name):
506+
obbject_copy = deepcopy(obbject)
507+
accessor = getattr(obbject_copy, ext.name)
508+
if iscoroutinefunction(accessor):
509+
run_async(accessor)
510+
elif callable(accessor):
511+
accessor()
512+
elif ext.immutable is False and hasattr(obbject, ext.name):
513+
accessor = getattr(obbject, ext.name)
514+
if iscoroutinefunction(accessor):
515+
run_async(accessor)
516+
elif callable(accessor):
517+
accessor()
518+
setattr(obbject, "_extension_modified", True)
519+
520+
if results_only is True:
521+
setattr(obbject, "_results_only", True)
522+
setattr(obbject, "_extension_modified", True)
523+
434524

435525
class CommandRunner:
436526
"""Command runner."""

openbb_platform/core/openbb_core/app/extension_loader.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,23 @@ def __init__(
4949
self._obbject_objects: dict[str, Extension] = {}
5050
self._core_objects: dict[str, Router] = {}
5151
self._provider_objects: dict[str, Provider] = {}
52+
self._on_command_output_callbacks: dict[str, list[Extension]] = {}
53+
self._register_command_output_callbacks()
54+
55+
@property
56+
def on_command_output_callbacks(self) -> dict[str, list[Extension]]:
57+
"""Return the on command output callbacks."""
58+
return self._on_command_output_callbacks
59+
60+
def _register_command_output_callbacks(self) -> None:
61+
"""Register extensions that act on command output."""
62+
for ext in self.obbject_objects.values():
63+
if ext.on_command_output:
64+
paths = ext.command_output_paths or ["*"]
65+
for path in paths:
66+
if path not in self._on_command_output_callbacks:
67+
self._on_command_output_callbacks[path] = []
68+
self._on_command_output_callbacks[path].append(ext)
5269

5370
@property
5471
def obbject_entry_points(self) -> EntryPoints:
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""OpenBB Core App Abstract Model Tagged."""
22

33
from pydantic import BaseModel, Field
4-
from uuid_extensions import uuid7str # type: ignore
4+
from uuid_extensions import uuid7str
55

66

77
class Tagged(BaseModel):
88
"""Model for Tagged."""
99

10-
id: str = Field(default_factory=uuid7str, alias="_id")
10+
id: str = Field(default_factory=uuid7str)

openbb_platform/core/openbb_core/app/model/extension.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,96 @@ class Extension:
88
"""
99
Serves as OBBject extension entry point and must be created by each extension package.
1010
11-
See https://docs.openbb.co/platform/development/developer-guidelines/obbject_extensions.
11+
See https://docs.openbb.co/developer/extension_types/obbject for more information.
1212
"""
1313

14+
# pylint: disable=R0917
1415
def __init__(
1516
self,
1617
name: str,
1718
credentials: list[str] | None = None,
1819
description: str | None = None,
20+
on_command_output: bool = False,
21+
command_output_paths: list[str] | None = None,
22+
immutable: bool = True,
23+
results_only: bool = False,
1924
) -> None:
2025
"""Initialize the extension.
2126
2227
Parameters
2328
----------
2429
name : str
2530
Name of the extension.
26-
credentials : Optional[List[str]], optional
31+
credentials : list[str], optional
2732
List of required credentials, by default None
2833
description: Optional[str]
2934
Extension description.
35+
on_command_output : bool, optional
36+
Whether the extension acts on command output, by default False
37+
command_output_paths : list[str], optional
38+
List of endpoint paths the extension acts on, where None means all, by default None.
39+
immutable : bool, optional
40+
Whether the function output is immutable, by default True.
41+
results_only : bool, optional
42+
Whether the extension returns only the results instead of the OBBject, by default False.
3043
"""
44+
# pylint: disable=import-outside-toplevel
45+
from openbb_core.app.service.system_service import SystemService
46+
3147
self.name = name
3248
self.credentials = credentials or []
3349
self.description = description
50+
self.on_command_output = on_command_output
51+
self.command_output_paths = command_output_paths or []
52+
self.immutable = immutable
53+
self.results_only = results_only
54+
55+
# This must be explicitly enabled.
56+
if self.on_command_output is False and (
57+
self.command_output_paths
58+
or self.results_only is True
59+
or self.immutable is False
60+
):
61+
raise ValueError(
62+
"OBBject Extension Error -> 'on_command_output' must be set as True when"
63+
+ " 'command_output_paths', 'results_only' or 'immutable' is set.",
64+
)
65+
66+
# The user must explicitly enable OBBject extensions that act on command output.
67+
if (
68+
self.on_command_output
69+
and not SystemService().system_settings.allow_on_command_output
70+
):
71+
raise RuntimeError(
72+
"OBBject Extension Error -> \n\n"
73+
+ "An OBBject extension that acts on command output is installed "
74+
+ "but has not been enabled in `system_settings.json`.\n\n"
75+
+ "Set `allow_on_command_output` to True to enable it.\n"
76+
+ "Or, set the environment variable `OPENBB_ALLOW_ON_COMMAND_OUTPUT` to True."
77+
+ "\n\nProceed with caution as this may have security implications.\n\n"
78+
+ "Ensure the extension is installed from a trusted source.\n\n",
79+
)
80+
81+
# The user must explicitly enable OBBject extensions that modify output.
82+
if (
83+
self.on_command_output
84+
and self.immutable is False
85+
and not SystemService().system_settings.allow_mutable_extensions
86+
):
87+
raise RuntimeError(
88+
"OBBject Extension Error -> \n\n"
89+
+ "An OBBject extension that modifies the output is installed "
90+
+ "but has not been enabled in `system_settings.json`.\n\n"
91+
+ "Set `allow_mutable_extensions` to True to enable it.\n"
92+
+ "Or, set the environment variable `OPENBB_ALLOW_MUTABLE_EXTENSIONS` to True."
93+
+ "\n\nProceed with caution as this may have security implications.\n\n"
94+
+ "Ensure the extension is installed from a trusted source.\n\n",
95+
)
3496

3597
@property
3698
def obbject_accessor(self) -> Callable:
3799
"""Extend an OBBject, inspired by pandas."""
38100
# pylint: disable=import-outside-toplevel
39-
# Avoid circular imports
40101

41102
from openbb_core.app.model.obbject import OBBject
42103

@@ -55,8 +116,8 @@ def decorator(accessor):
55116
UserWarning,
56117
)
57118
setattr(cls, name, CachedAccessor(name, accessor))
58-
# pylint: disable=protected-access
59119
cls.accessors.add(name)
120+
60121
return accessor
61122

62123
return decorator

openbb_platform/core/openbb_core/app/model/obbject.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class OBBject(Tagged, Generic[T]):
6060
default_factory=dict,
6161
description="Extra info.",
6262
)
63-
_route: str = PrivateAttr(
63+
_route: str | None = PrivateAttr(
6464
default=None,
6565
)
6666
_standard_params: dict[str, Any] | None = PrivateAttr(
@@ -178,9 +178,9 @@ def is_list_of_basemodel(items: list[T] | T) -> bool:
178178

179179
# BaseModel
180180
if isinstance(res, BaseModel):
181-
res_dict = res.model_dump(
181+
res_dict = res.model_dump( # pylint: disable=no-member
182182
exclude_unset=True, exclude_none=True
183-
) # pylint: disable=no-member
183+
)
184184
# Model is serialized as a dict[str, list] or list[dict]
185185
if (
186186
(
@@ -322,13 +322,16 @@ def to_dict(
322322
orient == "list"
323323
and isinstance(self.results, dict)
324324
and all(
325-
isinstance(value, dict) for value in self.results.values()
326-
) # pylint: disable=no-member
325+
isinstance(value, dict)
326+
for value in self.results.values() # pylint: disable=no-member
327+
)
327328
):
328329
df = df.T
329-
results = df.to_dict(orient=orient)
330+
results: dict | list = df.to_dict(orient=orient)
331+
330332
if isinstance(results, dict) and orient == "list" and "index" in results:
331333
del results["index"]
334+
332335
return results
333336

334337
def to_llm(self) -> dict[Hashable, Any] | list[dict[Hashable, Any]]:

0 commit comments

Comments
 (0)