diff --git a/.pylintrc b/.pylintrc index 554f953b8e..6f86387a2b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -414,7 +414,7 @@ max-bool-expr=5 max-branches=15 # Maximum number of locals for function / method body -max-locals=20 +max-locals=25 # Maximum number of parents for a class (see R0901). max-parents=7 diff --git a/dash/_callback.py b/dash/_callback.py index 8ebc274213..0c7476fda1 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,6 +1,7 @@ import collections import hashlib from functools import wraps +from typing import Callable, Optional, Any import flask @@ -67,6 +68,7 @@ def callback( cancel=None, manager=None, cache_args_to_ignore=None, + on_error: Optional[Callable[[Exception], Any]] = None, **_kwargs, ): """ @@ -137,6 +139,10 @@ def callback( this should be a list of argument indices as integers. :param interval: Time to wait between the long callback update requests. + :param on_error: + Function to call when the callback raises an exception. Receives the + exception object as first argument. The callback_context can be used + to access the original callback inputs, states and output. """ long_spec = None @@ -186,6 +192,7 @@ def callback( long=long_spec, manager=manager, running=running, + on_error=on_error, ) @@ -226,7 +233,7 @@ def insert_callback( long=None, manager=None, running=None, - dynamic_creator=False, + dynamic_creator: Optional[bool] = False, no_output=False, ): if prevent_initial_call is None: @@ -272,8 +279,16 @@ def insert_callback( return callback_id -# pylint: disable=R0912, R0915 -def register_callback( # pylint: disable=R0914 +def _set_side_update(ctx, response) -> bool: + side_update = dict(ctx.updated_props) + if len(side_update) > 0: + response["sideUpdate"] = side_update + return True + return False + + +# pylint: disable=too-many-branches,too-many-statements +def register_callback( callback_list, callback_map, config_prevent_initial_callbacks, *_args, **_kwargs ): ( @@ -297,6 +312,7 @@ def register_callback( # pylint: disable=R0914 long = _kwargs.get("long") manager = _kwargs.get("manager") running = _kwargs.get("running") + on_error = _kwargs.get("on_error") if running is not None: if not isinstance(running[0], (list, tuple)): running = [running] @@ -342,6 +358,8 @@ def add_context(*args, **kwargs): "callback_context", AttributeDict({"updated_props": {}}) ) callback_manager = long and long.get("manager", app_callback_manager) + error_handler = on_error or kwargs.pop("app_on_error", None) + if has_output: _validate.validate_output_spec(insert_output, output_spec, Output) @@ -351,7 +369,7 @@ def add_context(*args, **kwargs): args, inputs_state_indices ) - response = {"multi": True} + response: dict = {"multi": True} has_update = False if long is not None: @@ -440,10 +458,24 @@ def add_context(*args, **kwargs): isinstance(output_value, dict) and "long_callback_error" in output_value ): - error = output_value.get("long_callback_error") - raise LongCallbackError( + error = output_value.get("long_callback_error", {}) + exc = LongCallbackError( f"An error occurred inside a long callback: {error['msg']}\n{error['tb']}" ) + if error_handler: + output_value = error_handler(exc) + + if output_value is None: + output_value = NoUpdate() + # set_props from the error handler uses the original ctx + # instead of manager.get_updated_props since it runs in the + # request process. + has_update = ( + _set_side_update(callback_ctx, response) + or output_value is not None + ) + else: + raise exc if job_running and output_value is not callback_manager.UNDEFINED: # cached results. @@ -462,10 +494,22 @@ def add_context(*args, **kwargs): if output_value is callback_manager.UNDEFINED: return to_json(response) else: - output_value = _invoke_callback(func, *func_args, **func_kwargs) - - if NoUpdate.is_no_update(output_value): - raise PreventUpdate + try: + output_value = _invoke_callback(func, *func_args, **func_kwargs) + except PreventUpdate as err: + raise err + except Exception as err: # pylint: disable=broad-exception-caught + if error_handler: + output_value = error_handler(err) + + # If the error returns nothing, automatically puts NoUpdate for response. + if output_value is None: + if not multi: + output_value = NoUpdate() + else: + output_value = [NoUpdate for _ in output_spec] + else: + raise err component_ids = collections.defaultdict(dict) @@ -487,12 +531,12 @@ def add_context(*args, **kwargs): ) for val, spec in zip(flat_output_values, output_spec): - if isinstance(val, NoUpdate): + if NoUpdate.is_no_update(val): continue for vali, speci in ( zip(val, spec) if isinstance(spec, list) else [[val, spec]] ): - if not isinstance(vali, NoUpdate): + if not NoUpdate.is_no_update(vali): has_update = True id_str = stringify_id(speci["id"]) prop = clean_property_name(speci["property"]) @@ -506,10 +550,7 @@ def add_context(*args, **kwargs): flat_output_values = [] if not long: - side_update = dict(callback_ctx.updated_props) - if len(side_update) > 0: - has_update = True - response["sideUpdate"] = side_update + has_update = _set_side_update(callback_ctx, response) or has_update if not has_update: raise PreventUpdate diff --git a/dash/dash.py b/dash/dash.py index 3e5e4e4170..f499189838 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -16,7 +16,7 @@ import base64 import traceback from urllib.parse import urlparse -from typing import Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import flask @@ -369,6 +369,10 @@ class Dash: :param description: Sets a default description for meta tags on Dash pages (use_pages=True). + :param on_error: Global callback error handler to call when + an exception is raised. Receives the exception object as first argument. + The callback_context can be used to access the original callback inputs, + states and output. """ _plotlyjs_url: str @@ -409,6 +413,7 @@ def __init__( # pylint: disable=too-many-statements hooks: Union[RendererHooks, None] = None, routing_callback_inputs: Optional[Dict[str, Union[Input, State]]] = None, description=None, + on_error: Optional[Callable[[Exception], Any]] = None, **obsolete, ): _validate.check_obsolete(obsolete) @@ -520,6 +525,7 @@ def __init__( # pylint: disable=too-many-statements self._layout = None self._layout_is_function = False self.validation_layout = None + self._on_error = on_error self._extra_components = [] self._setup_dev_tools() @@ -1377,6 +1383,7 @@ def dispatch(self): outputs_list=outputs_list, long_callback_manager=self._background_manager, callback_context=g, + app_on_error=self._on_error, ) ) ) diff --git a/dash/long_callback/managers/celery_manager.py b/dash/long_callback/managers/celery_manager.py index 01fadf4f8d..612cc245fb 100644 --- a/dash/long_callback/managers/celery_manager.py +++ b/dash/long_callback/managers/celery_manager.py @@ -161,6 +161,7 @@ def run(): c.ignore_register_page = False c.updated_props = ProxySetProps(_set_props) context_value.set(c) + errored = False try: if isinstance(user_callback_args, dict): user_callback_output = fn(*maybe_progress, **user_callback_args) @@ -170,6 +171,7 @@ def run(): user_callback_output = fn(*maybe_progress, user_callback_args) except PreventUpdate: # Put NoUpdate dict directly to avoid circular imports. + errored = True cache.set( result_key, json.dumps( @@ -177,6 +179,7 @@ def run(): ), ) except Exception as err: # pylint: disable=broad-except + errored = True cache.set( result_key, json.dumps( @@ -188,7 +191,8 @@ def run(): }, ), ) - else: + + if not errored: cache.set( result_key, json.dumps(user_callback_output, cls=PlotlyJSONEncoder) ) diff --git a/dash/long_callback/managers/diskcache_manager.py b/dash/long_callback/managers/diskcache_manager.py index a106700f82..e1a110f14f 100644 --- a/dash/long_callback/managers/diskcache_manager.py +++ b/dash/long_callback/managers/diskcache_manager.py @@ -121,7 +121,8 @@ def call_job_fn(self, key, job_fn, args, context): # pylint: disable-next=not-callable proc = Process( - target=job_fn, args=(key, self._make_progress_key(key), args, context) + target=job_fn, + args=(key, self._make_progress_key(key), args, context), ) proc.start() return proc.pid @@ -187,6 +188,7 @@ def run(): c.ignore_register_page = False c.updated_props = ProxySetProps(_set_props) context_value.set(c) + errored = False try: if isinstance(user_callback_args, dict): user_callback_output = fn(*maybe_progress, **user_callback_args) @@ -195,8 +197,10 @@ def run(): else: user_callback_output = fn(*maybe_progress, user_callback_args) except PreventUpdate: + errored = True cache.set(result_key, {"_dash_no_update": "_dash_no_update"}) except Exception as err: # pylint: disable=broad-except + errored = True cache.set( result_key, { @@ -206,7 +210,8 @@ def run(): } }, ) - else: + + if not errored: cache.set(result_key, user_callback_output) ctx.run(run) diff --git a/tests/integration/callbacks/test_callback_error.py b/tests/integration/callbacks/test_callback_error.py new file mode 100644 index 0000000000..0a76ed741d --- /dev/null +++ b/tests/integration/callbacks/test_callback_error.py @@ -0,0 +1,46 @@ +from dash import Dash, html, Input, Output, set_props + + +def test_cber001_error_handler(dash_duo): + def global_callback_error_handler(err): + set_props("output-global", {"children": f"global: {err}"}) + + app = Dash(on_error=global_callback_error_handler) + + app.layout = [ + html.Button("start", id="start-local"), + html.Button("start-global", id="start-global"), + html.Div(id="output"), + html.Div(id="output-global"), + html.Div(id="error-message"), + ] + + def on_callback_error(err): + set_props("error-message", {"children": f"message: {err}"}) + return f"callback: {err}" + + @app.callback( + Output("output", "children"), + Input("start-local", "n_clicks"), + on_error=on_callback_error, + prevent_initial_call=True, + ) + def on_start(_): + raise Exception("local error") + + @app.callback( + Output("output-global", "children"), + Input("start-global", "n_clicks"), + prevent_initial_call=True, + ) + def on_start_global(_): + raise Exception("global error") + + dash_duo.start_server(app) + dash_duo.find_element("#start-local").click() + + dash_duo.wait_for_text_to_equal("#output", "callback: local error") + dash_duo.wait_for_text_to_equal("#error-message", "message: local error") + + dash_duo.find_element("#start-global").click() + dash_duo.wait_for_text_to_equal("#output-global", "global: global error") diff --git a/tests/integration/long_callback/app_bg_on_error.py b/tests/integration/long_callback/app_bg_on_error.py new file mode 100644 index 0000000000..3132b6d8a0 --- /dev/null +++ b/tests/integration/long_callback/app_bg_on_error.py @@ -0,0 +1,50 @@ +from dash import Dash, Input, Output, html, set_props +from tests.integration.long_callback.utils import get_long_callback_manager + +long_callback_manager = get_long_callback_manager() +handle = long_callback_manager.handle + + +def global_error_handler(err): + set_props("global-output", {"children": f"global: {err}"}) + + +app = Dash( + __name__, long_callback_manager=long_callback_manager, on_error=global_error_handler +) + +app.layout = [ + html.Button("callback on_error", id="start-cb-onerror"), + html.Div(id="cb-output"), + html.Button("global on_error", id="start-global-onerror"), + html.Div(id="global-output"), +] + + +def callback_on_error(err): + set_props("cb-output", {"children": f"callback: {err}"}) + + +@app.callback( + Output("cb-output", "children"), + Input("start-cb-onerror", "n_clicks"), + prevent_initial_call=True, + background=True, + on_error=callback_on_error, +) +def on_click(_): + raise Exception("callback error") + + +@app.callback( + Output("global-output", "children"), + Input("start-global-onerror", "n_clicks"), + prevent_initial_call=True, + background=True, +) +def on_click_global(_): + raise Exception("global error") + + +if __name__ == "__main__": + app.run(debug=True) diff --git a/tests/integration/long_callback/test_basic_long_callback018.py b/tests/integration/long_callback/test_basic_long_callback018.py new file mode 100644 index 0000000000..7dd0ca36d7 --- /dev/null +++ b/tests/integration/long_callback/test_basic_long_callback018.py @@ -0,0 +1,13 @@ +from tests.integration.long_callback.utils import setup_long_callback_app + + +def test_lcbc018_background_callback_on_error(dash_duo, manager): + with setup_long_callback_app(manager, "app_bg_on_error") as app: + dash_duo.start_server(app) + + dash_duo.find_element("#start-cb-onerror").click() + + dash_duo.wait_for_contains_text("#cb-output", "callback error") + + dash_duo.find_element("#start-global-onerror").click() + dash_duo.wait_for_contains_text("#global-output", "global error") diff --git a/tests/integration/long_callback/utils.py b/tests/integration/long_callback/utils.py index 83b804848c..8c5c9e2a4f 100644 --- a/tests/integration/long_callback/utils.py +++ b/tests/integration/long_callback/utils.py @@ -18,7 +18,13 @@ def __init__(self, cache=None, cache_by=None, expire=None): super().__init__(cache=cache, cache_by=cache_by, expire=expire) self.running_jobs = [] - def call_job_fn(self, key, job_fn, args, context): + def call_job_fn( + self, + key, + job_fn, + args, + context, + ): pid = super().call_job_fn(key, job_fn, args, context) self.running_jobs.append(pid) return pid @@ -135,8 +141,9 @@ def setup_long_callback_app(manager_name, app_name): # Sleep for a couple of intervals time.sleep(2.0) - for job in manager.running_jobs: - manager.terminate_job(job) + if hasattr(manager, "running_jobs"): + for job in manager.running_jobs: + manager.terminate_job(job) shutil.rmtree(cache_directory, ignore_errors=True) os.environ.pop("LONG_CALLBACK_MANAGER")