diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index 1fc0df1845..be5caf4929 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -17,6 +17,7 @@ jobs:
# This output will be 'true' if files in the 'table_related_paths' list changed, 'false' otherwise.
table_paths_changed: ${{ steps.filter.outputs.table_related_paths }}
background_cb_changed: ${{ steps.filter.outputs.background_paths }}
+ backend_cb_changed: ${{ steps.filter.outputs.backend_paths }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -37,6 +38,9 @@ jobs:
- 'tests/background_callback/**'
- 'tests/async_tests/**'
- 'requirements/**'
+ backend_paths:
+ - 'dash/backend/**'
+ - 'tests/backend/**'
build:
name: Build Dash Package
@@ -271,6 +275,109 @@ jobs:
cd bgtests
pytest --headless --nopercyfinalize tests/async_tests -v -s
+ backend-tests:
+ name: Run Backend Callback Tests (Python ${{ matrix.python-version }})
+ needs: [build, changes_filter]
+ if: |
+ (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) ||
+ needs.changes_filter.outputs.backend_cb_changed == 'true'
+ timeout-minutes: 30
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: ["3.9", "3.12"]
+
+ services:
+ redis:
+ image: redis:6
+ ports:
+ - 6379:6379
+ options: >-
+ --health-cmd "redis-cli ping"
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+
+ env:
+ REDIS_URL: redis://localhost:6379
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Set up Node.js
+ uses: actions/setup-node@v4
+ with:
+ node-version: '20'
+ cache: 'npm'
+
+ - name: Install Node.js dependencies
+ run: npm ci
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+
+ - name: Download built Dash packages
+ uses: actions/download-artifact@v4
+ with:
+ name: dash-packages
+ path: packages/
+
+ - name: Install Dash packages
+ run: |
+ python -m pip install --upgrade pip wheel
+ python -m pip install "setuptools<78.0.0"
+ python -m pip install "selenium==4.32.0"
+ find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache,fastapi,quart]"' \;
+
+ - name: Install Google Chrome
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y google-chrome-stable
+
+ - name: Install ChromeDriver
+ run: |
+ echo "Determining Chrome version..."
+ CHROME_BROWSER_VERSION=$(google-chrome --version)
+ echo "Installed Chrome Browser version: $CHROME_BROWSER_VERSION"
+ CHROME_MAJOR_VERSION=$(echo "$CHROME_BROWSER_VERSION" | cut -f 3 -d ' ' | cut -f 1 -d '.')
+ echo "Detected Chrome Major version: $CHROME_MAJOR_VERSION"
+ if [ "$CHROME_MAJOR_VERSION" -ge 115 ]; then
+ echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using CfT endpoint..."
+ CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://googlechromelabs.github.io/chrome-for-testing/LATEST_RELEASE_${CHROME_MAJOR_VERSION}")
+ if [ -z "$CHROMEDRIVER_VERSION_STRING" ]; then
+ echo "Could not automatically find ChromeDriver version for Chrome $CHROME_MAJOR_VERSION via LATEST_RELEASE. Please check CfT endpoints."
+ exit 1
+ fi
+ CHROMEDRIVER_URL="https://edgedl.me.gvt1.com/edgedl/chrome/chrome-for-testing/${CHROMEDRIVER_VERSION_STRING}/linux64/chromedriver-linux64.zip"
+ else
+ echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using older method..."
+ CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://chromedriver.storage.googleapis.com/LATEST_RELEASE_${CHROME_MAJOR_VERSION}")
+ CHROMEDRIVER_URL="https://chromedriver.storage.googleapis.com/${CHROMEDRIVER_VERSION_STRING}/chromedriver_linux64.zip"
+ fi
+ echo "Using ChromeDriver version string: $CHROMEDRIVER_VERSION_STRING"
+ echo "Downloading ChromeDriver from: $CHROMEDRIVER_URL"
+ wget -q -O chromedriver.zip "$CHROMEDRIVER_URL"
+ unzip -o chromedriver.zip -d /tmp/
+ sudo mv /tmp/chromedriver-linux64/chromedriver /usr/local/bin/chromedriver || sudo mv /tmp/chromedriver /usr/local/bin/chromedriver
+ sudo chmod +x /usr/local/bin/chromedriver
+ echo "/usr/local/bin" >> $GITHUB_PATH
+ shell: bash
+
+ - name: Build/Setup test components
+ run: npm run setup-tests.py
+
+ - name: Run Backend Callback Tests
+ run: |
+ mkdir bgtests
+ cp -r tests bgtests/tests
+ cd bgtests
+ touch __init__.py
+ pytest --headless --nopercyfinalize tests/backend_tests -v -s
+
table-unit:
name: Table Unit/Lint Tests (Python ${{ matrix.python-version }})
needs: [build, changes_filter]
diff --git a/dash/_callback.py b/dash/_callback.py
index aacb8dbdde..6cc55b9162 100644
--- a/dash/_callback.py
+++ b/dash/_callback.py
@@ -6,7 +6,7 @@
import asyncio
-import flask
+from dash.backend import get_request_adapter
from .dependencies import (
handle_callback_args,
@@ -376,7 +376,7 @@ def _get_callback_manager(
" and store results on redis.\n"
)
- old_job = flask.request.args.getlist("oldJob")
+ old_job = get_request_adapter().get_args().getlist("oldJob")
if old_job:
for job in old_job:
@@ -436,7 +436,7 @@ def _setup_background_callback(
def _progress_background_callback(response, callback_manager, background):
progress_outputs = background.get("progress")
- cache_key = flask.request.args.get("cacheKey")
+ cache_key = get_request_adapter().get_args().get("cacheKey")
if progress_outputs:
# Get the progress before the result as it would be erased after the results.
@@ -453,8 +453,8 @@ def _update_background_callback(
"""Set up the background callback and manage jobs."""
callback_manager = _get_callback_manager(kwargs, background)
- cache_key = flask.request.args.get("cacheKey")
- job_id = flask.request.args.get("job")
+ cache_key = get_request_adapter().get_args().get("cacheKey")
+ job_id = get_request_adapter().get_args().get("job")
_progress_background_callback(response, callback_manager, background)
@@ -474,8 +474,8 @@ def _handle_rest_background_callback(
multi,
has_update=False,
):
- cache_key = flask.request.args.get("cacheKey")
- job_id = flask.request.args.get("job")
+ cache_key = get_request_adapter().get_args().get("cacheKey")
+ job_id = get_request_adapter().get_args().get("job")
# Must get job_running after get_result since get_results terminates it.
job_running = callback_manager.job_running(job_id)
if not job_running and output_value is callback_manager.UNDEFINED:
@@ -688,11 +688,10 @@ def add_context(*args, **kwargs):
)
response: dict = {"multi": True}
-
jsonResponse = None
try:
if background is not None:
- if not flask.request.args.get("cacheKey"):
+ if not get_request_adapter().get_args().get("cacheKey"):
return _setup_background_callback(
kwargs,
background,
@@ -763,7 +762,7 @@ async def async_add_context(*args, **kwargs):
try:
if background is not None:
- if not flask.request.args.get("cacheKey"):
+ if not get_request_adapter().get_args().get("cacheKey"):
return _setup_background_callback(
kwargs,
background,
diff --git a/dash/_callback_context.py b/dash/_callback_context.py
index f64865c464..72b92e09e2 100644
--- a/dash/_callback_context.py
+++ b/dash/_callback_context.py
@@ -288,6 +288,14 @@ def path(self):
"""
return _get_from_context("path", "")
+ @property
+ @has_context
+ def args(self):
+ """
+ Query parameters of the callback request as a dictionary-like object.
+ """
+ return _get_from_context("args", "")
+
@property
@has_context
def remote(self):
diff --git a/dash/_pages.py b/dash/_pages.py
index 45538546e8..acb26e8791 100644
--- a/dash/_pages.py
+++ b/dash/_pages.py
@@ -389,15 +389,15 @@ def _path_to_page(path_id):
return {}, None
-def _page_meta_tags(app):
- start_page, path_variables = _path_to_page(flask.request.path.strip("/"))
+def _page_meta_tags(app, request):
+ request_path = request.get_path()
+ start_page, path_variables = _path_to_page(request_path.strip("/"))
- # use the supplied image_url or create url based on image in the assets folder
image = start_page.get("image", "")
if image:
image = app.get_asset_url(image)
assets_image_url = (
- "".join([flask.request.url_root, image.lstrip("/")]) if image else None
+ "".join([request.get_root(), image.lstrip("/")]) if image else None
)
supplied_image_url = start_page.get("image_url")
image_url = supplied_image_url if supplied_image_url else assets_image_url
@@ -413,7 +413,7 @@ def _page_meta_tags(app):
return [
{"name": "description", "content": description},
{"property": "twitter:card", "content": "summary_large_image"},
- {"property": "twitter:url", "content": flask.request.url},
+ {"property": "twitter:url", "content": request.get_url()},
{"property": "twitter:title", "content": title},
{"property": "twitter:description", "content": description},
{"property": "twitter:image", "content": image_url or ""},
diff --git a/dash/_utils.py b/dash/_utils.py
index f118e61538..ef6c63c281 100644
--- a/dash/_utils.py
+++ b/dash/_utils.py
@@ -104,6 +104,11 @@ def set_read_only(self, names, msg="Attribute is read-only"):
else:
object.__setattr__(self, "_read_only", new_read_only)
+ def unset_read_only(self, keys):
+ if hasattr(self, "_read_only"):
+ for key in keys:
+ self._read_only.pop(key, None)
+
def finalize(self, msg="Object is final: No new keys may be added."):
"""Prevent any new keys being set."""
object.__setattr__(self, "_final", msg)
diff --git a/dash/backend/__init__.py b/dash/backend/__init__.py
new file mode 100644
index 0000000000..eb1d47bc3f
--- /dev/null
+++ b/dash/backend/__init__.py
@@ -0,0 +1,15 @@
+# python
+import contextvars
+from .registry import get_backend # pylint: disable=unused-import
+
+__all__ = ["set_request_adapter", "get_request_adapter", "get_backend"]
+
+_request_adapter_var = contextvars.ContextVar("request_adapter")
+
+
+def set_request_adapter(adapter):
+ _request_adapter_var.set(adapter)
+
+
+def get_request_adapter():
+ return _request_adapter_var.get()
diff --git a/dash/backend/base_server.py b/dash/backend/base_server.py
new file mode 100644
index 0000000000..4855f86ad6
--- /dev/null
+++ b/dash/backend/base_server.py
@@ -0,0 +1,58 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class BaseDashServer(ABC):
+ def __call__(self, server, *args, **kwargs) -> Any:
+ # Default: WSGI
+ return server(*args, **kwargs)
+
+ @abstractmethod
+ def create_app(
+ self, name: str = "__main__", config=None
+ ) -> Any: # pragma: no cover - interface
+ pass
+
+ @abstractmethod
+ def register_assets_blueprint(
+ self, app, blueprint_name: str, assets_url_path: str, assets_folder: str
+ ) -> None: # pragma: no cover - interface
+ pass
+
+ @abstractmethod
+ def register_error_handlers(self, app) -> None: # pragma: no cover - interface
+ pass
+
+ @abstractmethod
+ def add_url_rule(
+ self, app, rule: str, view_func, endpoint=None, methods=None
+ ) -> None: # pragma: no cover - interface
+ pass
+
+ @abstractmethod
+ def before_request(self, app, func) -> None: # pragma: no cover - interface
+ pass
+
+ @abstractmethod
+ def after_request(self, app, func) -> None: # pragma: no cover - interface
+ pass
+
+ @abstractmethod
+ def run(
+ self, app, host: str, port: int, debug: bool, **kwargs
+ ) -> None: # pragma: no cover - interface
+ pass
+
+ @abstractmethod
+ def make_response(
+ self, data, mimetype=None, content_type=None
+ ) -> Any: # pragma: no cover - interface
+ pass
+
+ @abstractmethod
+ def jsonify(self, obj) -> Any: # pragma: no cover - interface
+ pass
+
+ @abstractmethod
+ def get_request_adapter(self) -> Any: # pragma: no cover - interface
+ pass
diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py
new file mode 100644
index 0000000000..8c402cb187
--- /dev/null
+++ b/dash/backend/fastapi.py
@@ -0,0 +1,511 @@
+import sys
+import mimetypes
+import hashlib
+import inspect
+import pkgutil
+from contextvars import copy_context
+import importlib.util
+import time
+import traceback
+import re
+
+try:
+ import uvicorn
+ from fastapi import FastAPI, Request, Response
+ from fastapi.responses import JSONResponse
+ from fastapi.staticfiles import StaticFiles
+ from starlette.responses import Response as StarletteResponse
+ from starlette.datastructures import MutableHeaders
+ from pydantic import create_model
+ from typing import Any, Optional
+except ImportError:
+ uvicorn = None
+ FastAPI = None
+ Request = None
+ Response = None
+ JSONResponse = None
+ StaticFiles = None
+ StarletteResponse = None
+ MutableHeaders = None
+ create_model = None
+ Any = None
+ Optional = None
+
+
+import json
+import os
+from dash.fingerprint import check_fingerprint
+from dash import _validate
+from dash.exceptions import (
+ PreventUpdate,
+)
+from dash.backend import set_request_adapter
+from .base_server import BaseDashServer
+
+CONFIG_PATH = "dash_config.json"
+
+
+def save_config(config):
+ with open(CONFIG_PATH, "w") as f:
+ json.dump(config, f)
+
+
+def load_config():
+ if os.path.exists(CONFIG_PATH):
+ with open(CONFIG_PATH, "r") as f:
+ return json.load(f)
+ return {}
+
+
+class FastAPIDashServer(BaseDashServer):
+ def __init__(self):
+ self.error_handling_mode = "prune"
+ super().__init__()
+
+ def __call__(self, server, *args, **kwargs):
+ # ASGI: (scope, receive, send)
+ if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]:
+ return server(*args, **kwargs)
+ raise TypeError("FastAPI app must be called with (scope, receive, send)")
+
+ def create_app(self, name="__main__", config=None):
+ app = FastAPI()
+ if config:
+ for key, value in config.items():
+ setattr(app.state, key, value)
+ return app
+
+ def register_assets_blueprint(
+ self, app, blueprint_name, assets_url_path, assets_folder
+ ):
+ try:
+ app.mount(
+ assets_url_path,
+ StaticFiles(directory=assets_folder),
+ name=blueprint_name,
+ )
+ except RuntimeError:
+ # directory doesnt exist
+ pass
+
+ def register_error_handlers(self, app):
+ self.error_handling_mode = "prune"
+
+ def _get_traceback(self, _secret, error: Exception):
+ tb = error.__traceback__
+ errors = traceback.format_exception(type(error), error, tb)
+ pass_errs = []
+ callback_handled = False
+ for err in errors:
+ if self.error_handling_mode == "prune":
+ if not callback_handled:
+ if "callback invoked" in str(err) and "_callback.py" in str(err):
+ callback_handled = True
+ continue
+ pass_errs.append(err)
+ formatted_tb = "".join(pass_errs)
+ error_type = type(error).__name__
+ error_msg = str(error)
+
+ # Parse traceback lines to group by file
+ file_cards = []
+ pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)')
+ lines = formatted_tb.split("\n")
+ current_file = None
+ card_lines = []
+
+ for line in lines[:-1]: # Skip the last line (error message)
+ match = pattern.match(line)
+ if match:
+ if current_file and card_lines:
+ file_cards.append((current_file, card_lines))
+ current_file = (
+ f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})"
+ )
+ card_lines = [line]
+ elif current_file:
+ card_lines.append(line)
+ if current_file and card_lines:
+ file_cards.append((current_file, card_lines))
+
+ cards_html = ""
+ for filename, card in file_cards:
+ cards_html += (
+ f"""
+
+
+
"""
+ + "\n".join(card)
+ + """
+
+ """
+ )
+
+ html = f"""
+
+
+
+ {error_type}: {error_msg} // FastAPI Debugger
+
+
+
+
+
{error_type}
+
+
{error_type}: {error_msg}
+
+
Traceback (most recent call last)
+ {cards_html}
+
{error_type}: {error_msg}
+
+
This is the Copy/Paste friendly version of the traceback.
+
+
+
+ The debugger caught an exception in your ASGI application. You can now
+ look at the traceback which led to the error.
+
+
+
+
+
+ """
+ return html
+
+ def register_prune_error_handler(self, _app, _secret, prune_errors):
+ if prune_errors:
+ self.error_handling_mode = "prune"
+ else:
+ self.error_handling_mode = "raise"
+
+ def _html_response_wrapper(self, view_func):
+ async def wrapped(*_args, **_kwargs):
+ # If view_func is a function, call it; if it's a string, use it directly
+ html = view_func() if callable(view_func) else view_func
+ return Response(content=html, media_type="text/html")
+
+ return wrapped
+
+ def setup_index(self, dash_app):
+ async def index(request: Request):
+ adapter = FastAPIRequestAdapter()
+ set_request_adapter(adapter)
+ adapter.set_request(request)
+ return Response(content=dash_app.index(), media_type="text/html")
+
+ # pylint: disable=protected-access
+ dash_app._add_url("", index, methods=["GET"])
+
+ def setup_catchall(self, dash_app):
+ @dash_app.server.on_event("startup")
+ def _setup_catchall():
+ config = load_config()
+ dash_app.enable_dev_tools(**config, first_run=False)
+
+ async def catchall(request: Request):
+ adapter = FastAPIRequestAdapter()
+ set_request_adapter(adapter)
+ adapter.set_request(request)
+ return Response(content=dash_app.index(), media_type="text/html")
+
+ # pylint: disable=protected-access
+ dash_app._add_url("{path:path}", catchall, methods=["GET"])
+
+ def add_url_rule(
+ self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False
+ ):
+ if rule == "":
+ rule = "/"
+ if isinstance(view_func, str):
+ # Wrap string or sync function to async FastAPI handler
+ view_func = self._html_response_wrapper(view_func)
+ app.add_api_route(
+ rule,
+ view_func,
+ methods=methods or ["GET"],
+ name=endpoint,
+ include_in_schema=include_in_schema,
+ )
+
+ def before_request(self, app, func):
+ # FastAPI does not have before_request, but we can use middleware
+ app.middleware("http")(self._make_before_middleware(func))
+
+ def after_request(self, app, func):
+ # FastAPI does not have after_request, but we can use middleware
+ app.middleware("http")(self._make_after_middleware(func))
+
+ def run(self, dash_app, app, host, port, debug, **kwargs):
+ frame = inspect.stack()[2]
+ config = dict(
+ {"debug": debug} if debug else {},
+ **{
+ f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items()
+ }, # pylint: disable=protected-access
+ )
+ save_config(config)
+ if debug:
+ if kwargs.get("reload") is None:
+ kwargs["reload"] = True
+ if kwargs.get("reload"):
+ # Dynamically determine the module name from the file path
+ file_path = frame.filename
+ module_name = importlib.util.spec_from_file_location("app", file_path).name
+ uvicorn.run(
+ f"{module_name}:app.server",
+ host=host,
+ port=port,
+ **kwargs,
+ )
+ else:
+ uvicorn.run(app, host=host, port=port, **kwargs)
+
+ def make_response(self, data, mimetype=None, content_type=None):
+ headers = {}
+ if mimetype:
+ headers["content-type"] = mimetype
+ if content_type:
+ headers["content-type"] = content_type
+ return Response(content=data, headers=headers)
+
+ def jsonify(self, obj):
+ return JSONResponse(content=obj)
+
+ def get_request_adapter(self):
+ return FastAPIRequestAdapter
+
+ def _make_before_middleware(self, _func):
+ async def middleware(request, call_next):
+ try:
+ response = await call_next(request)
+ return response
+ except PreventUpdate:
+ # No content, nothing to update
+ return Response(status_code=204)
+ except Exception as e:
+ if self.error_handling_mode in ["raise", "prune"]:
+ # Prune the traceback to remove internal Dash calls
+ tb = self._get_traceback(None, e)
+ return Response(content=tb, media_type="text/html", status_code=500)
+ return JSONResponse(
+ status_code=500,
+ content={"error": "InternalServerError", "message": str(e.args[0])},
+ )
+
+ return middleware
+
+ def _make_after_middleware(self, func):
+ async def middleware(request, call_next):
+ response = await call_next(request)
+ if func is not None:
+ if inspect.iscoroutinefunction(func):
+ await func()
+ else:
+ func()
+ return response
+
+ return middleware
+
+ def serve_component_suites(
+ self, dash_app, package_name, fingerprinted_path, request
+ ):
+ path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path)
+ _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg)
+ extension = "." + path_in_pkg.split(".")[-1]
+ mimetype = mimetypes.types_map.get(extension, "application/octet-stream")
+ package = sys.modules[package_name]
+ dash_app.logger.debug(
+ "serving -- package: %s[%s] resource: %s => location: %s",
+ package_name,
+ package.__version__,
+ path_in_pkg,
+ package.__path__,
+ )
+ data = pkgutil.get_data(package_name, path_in_pkg)
+ headers = {}
+ if has_fingerprint:
+ headers["Cache-Control"] = "public, max-age=31536000"
+ return StarletteResponse(content=data, media_type=mimetype, headers=headers)
+ etag = hashlib.md5(data).hexdigest() if data else ""
+ headers["ETag"] = etag
+ if request.headers.get("if-none-match") == etag:
+ return StarletteResponse(status_code=304)
+ return StarletteResponse(content=data, media_type=mimetype, headers=headers)
+
+ def setup_component_suites(self, dash_app):
+ async def serve(request: Request, package_name: str, fingerprinted_path: str):
+ return self.serve_component_suites(
+ dash_app, package_name, fingerprinted_path, request
+ )
+
+ # pylint: disable=protected-access
+ dash_app._add_url(
+ "_dash-component-suites/{package_name}/{fingerprinted_path:path}",
+ serve,
+ )
+
+ # pylint: disable=unused-argument
+ def dispatch(self, app, dash_app, use_async=False):
+ async def _dispatch(request: Request):
+ adapter = FastAPIRequestAdapter()
+ set_request_adapter(adapter)
+ adapter.set_request(request)
+ # pylint: disable=protected-access
+ body = await request.json()
+ g = dash_app._initialize_context(
+ body, adapter
+ ) # pylint: disable=protected-access
+ func = dash_app._prepare_callback(
+ g, body
+ ) # pylint: disable=protected-access
+ args = dash_app._inputs_to_vals(
+ g.inputs_list + g.states_list
+ ) # pylint: disable=protected-access
+ ctx = copy_context()
+ partial_func = dash_app._execute_callback(
+ func, args, g.outputs_list, g
+ ) # pylint: disable=protected-access
+ response_data = ctx.run(partial_func)
+ if inspect.iscoroutine(response_data):
+ response_data = await response_data
+ # Instead of set_data, return a new Response
+ return Response(content=response_data, media_type="application/json")
+
+ return _dispatch
+
+ def _serve_default_favicon(self):
+ return Response(
+ content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon"
+ )
+
+ def register_timing_hooks(self, app, first_run):
+ if not first_run:
+ return
+
+ @app.middleware("http")
+ async def timing_middleware(request, call_next):
+ # Before request
+ request.state.timing_information = {
+ "__dash_server": {"dur": time.time(), "desc": None}
+ }
+ response = await call_next(request)
+ # After request
+ timing_information = getattr(request.state, "timing_information", None)
+ if timing_information is not None:
+ dash_total = timing_information.get("__dash_server", None)
+ if dash_total is not None:
+ dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000)
+ headers = MutableHeaders(response.headers)
+ for name, info in timing_information.items():
+ value = name
+ if info.get("desc") is not None:
+ value += f';desc="{info["desc"]}"'
+ if info.get("dur") is not None:
+ value += f";dur={info['dur']}"
+ headers.append("Server-Timing", value)
+ return response
+
+ def register_callback_api_routes(self, app, callback_api_paths):
+ """
+ Register callback API endpoints on the FastAPI app.
+ Each key in callback_api_paths is a route, each value is a handler (sync or async).
+ Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter.
+ """
+ for path, handler in callback_api_paths.items():
+ endpoint = f"dash_callback_api_{path}"
+ route = path if path.startswith("/") else f"/{path}"
+ methods = ["POST"]
+ sig = inspect.signature(handler)
+ param_names = list(sig.parameters.keys())
+ fields = {name: (Optional[Any], None) for name in param_names}
+ Model = create_model(
+ f"Payload_{endpoint}", **fields
+ ) # pylint: disable=cell-var-from-loop
+
+ # pylint: disable=cell-var-from-loop
+ async def view_func(request: Request, body: Model):
+ kwargs = body.dict(exclude_unset=True)
+ if inspect.iscoroutinefunction(handler):
+ result = await handler(**kwargs)
+ else:
+ result = handler(**kwargs)
+ return JSONResponse(content=result)
+
+ app.add_api_route(
+ route,
+ view_func,
+ methods=methods,
+ name=endpoint,
+ include_in_schema=True,
+ )
+
+
+class FastAPIRequestAdapter:
+ def __init__(self):
+ self._request = None
+
+ def set_request(self, request: Request):
+ self._request = request
+
+ def get_root(self):
+ return str(self._request.base_url)
+
+ def get_args(self):
+ return self._request.query_params
+
+ async def get_json(self):
+ return await self._request.json()
+
+ def is_json(self):
+ return self._request.headers.get("content-type", "").startswith(
+ "application/json"
+ )
+
+ def get_cookies(self, _request=None):
+ return self._request.cookies
+
+ def get_headers(self):
+ return self._request.headers
+
+ def get_full_path(self):
+ return str(self._request.url)
+
+ def get_url(self):
+ return str(self._request.url)
+
+ def get_remote_addr(self):
+ return self._request.client.host if self._request.client else None
+
+ def get_origin(self):
+ return self._request.headers.get("origin")
+
+ def get_path(self):
+ return self._request.url.path
diff --git a/dash/backend/flask.py b/dash/backend/flask.py
new file mode 100644
index 0000000000..cf544ef5bc
--- /dev/null
+++ b/dash/backend/flask.py
@@ -0,0 +1,322 @@
+from contextvars import copy_context
+import asyncio
+import pkgutil
+import sys
+import mimetypes
+import time
+import inspect
+import traceback
+import flask
+from dash.fingerprint import check_fingerprint
+from dash import _validate
+from dash._callback import _invoke_callback, _async_invoke_callback
+from dash.exceptions import PreventUpdate, InvalidResourceError
+from dash.backend import set_request_adapter
+from .base_server import BaseDashServer
+
+
+class FlaskDashServer(BaseDashServer):
+ def __call__(self, server, *args, **kwargs):
+ # Always WSGI
+ return server(*args, **kwargs)
+
+ def create_app(self, name="__main__", config=None):
+ app = flask.Flask(name)
+ if config:
+ app.config.update(config)
+ return app
+
+ def register_assets_blueprint(
+ self, app, blueprint_name, assets_url_path, assets_folder
+ ):
+ bp = flask.Blueprint(
+ blueprint_name,
+ __name__,
+ static_folder=assets_folder,
+ static_url_path=assets_url_path,
+ )
+ app.register_blueprint(bp)
+
+ def register_error_handlers(self, app):
+ @app.errorhandler(PreventUpdate)
+ def _handle_error(_):
+ return "", 204
+
+ @app.errorhandler(InvalidResourceError)
+ def _invalid_resources_handler(err):
+ return err.args[0], 404
+
+ def _get_traceback(self, secret, error: Exception):
+ try:
+ from werkzeug.debug import (
+ tbtools,
+ ) # pylint: disable=import-outside-toplevel
+ except ImportError:
+ tbtools = None
+
+ def _get_skip(error):
+ tb = error.__traceback__
+ skip = 1
+ while tb.tb_next is not None:
+ skip += 1
+ tb = tb.tb_next
+ if tb.tb_frame.f_code in [
+ _invoke_callback.__code__,
+ _async_invoke_callback.__code__,
+ ]:
+ return skip
+ return skip
+
+ def _do_skip(error):
+ tb = error.__traceback__
+ while tb.tb_next is not None:
+ if tb.tb_frame.f_code in [
+ _invoke_callback.__code__,
+ _async_invoke_callback.__code__,
+ ]:
+ return tb.tb_next
+ tb = tb.tb_next
+ return error.__traceback__
+
+ if hasattr(tbtools, "get_current_traceback"):
+ return tbtools.get_current_traceback(skip=_get_skip(error)).render_full()
+ if hasattr(tbtools, "DebugTraceback"):
+ return tbtools.DebugTraceback(
+ error, skip=_get_skip(error)
+ ).render_debugger_html(True, secret, True)
+ return "".join(traceback.format_exception(type(error), error, _do_skip(error)))
+
+ def register_prune_error_handler(self, app, secret, prune_errors):
+ if prune_errors:
+
+ @app.errorhandler(Exception)
+ def _wrap_errors(error):
+ tb = self._get_traceback(secret, error)
+ return tb, 500
+
+ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None):
+ app.add_url_rule(
+ rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"]
+ )
+
+ def before_request(self, app, func):
+ app.before_request(func)
+
+ def after_request(self, app, func):
+ app.after_request(func)
+
+ def run(self, _dash_app, app, host, port, debug, **kwargs):
+ app.run(host=host, port=port, debug=debug, **kwargs)
+
+ def make_response(self, data, mimetype=None, content_type=None):
+ return flask.Response(data, mimetype=mimetype, content_type=content_type)
+
+ def jsonify(self, obj):
+ return flask.jsonify(obj)
+
+ def get_request_adapter(self):
+ return FlaskRequestAdapter
+
+ def setup_catchall(self, dash_app):
+ def catchall(*args, **kwargs):
+ adapter = FlaskRequestAdapter()
+ set_request_adapter(adapter)
+ return dash_app.index(*args, **kwargs)
+
+ # pylint: disable=protected-access
+ dash_app._add_url("", catchall, methods=["GET"])
+
+ def setup_index(self, dash_app):
+ def index(*args, **kwargs):
+ adapter = FlaskRequestAdapter()
+ set_request_adapter(adapter)
+ return dash_app.index(*args, **kwargs)
+
+ # pylint: disable=protected-access
+ dash_app._add_url("", index, methods=["GET"])
+
+ def serve_component_suites(self, dash_app, package_name, fingerprinted_path):
+ path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path)
+ _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg)
+ extension = "." + path_in_pkg.split(".")[-1]
+ mimetype = mimetypes.types_map.get(extension, "application/octet-stream")
+ package = sys.modules[package_name]
+ dash_app.logger.debug(
+ "serving -- package: %s[%s] resource: %s => location: %s",
+ package_name,
+ package.__version__,
+ path_in_pkg,
+ package.__path__,
+ )
+ data = pkgutil.get_data(package_name, path_in_pkg)
+ response = flask.Response(data, mimetype=mimetype)
+ if has_fingerprint:
+ response.cache_control.max_age = 31536000 # 1 year
+ else:
+ response.add_etag()
+ tag = response.get_etag()[0]
+ request_etag = flask.request.headers.get("If-None-Match")
+ if f'"{tag}"' == request_etag:
+ response = flask.Response(None, status=304)
+ return response
+
+ def setup_component_suites(self, dash_app):
+ def serve(package_name, fingerprinted_path):
+ return self.serve_component_suites(
+ dash_app, package_name, fingerprinted_path
+ )
+
+ # pylint: disable=protected-access
+ dash_app._add_url(
+ "_dash-component-suites//",
+ serve,
+ )
+
+ # pylint: disable=unused-argument
+ def dispatch(self, app, dash_app, use_async=False):
+ def _dispatch():
+ adapter = FlaskRequestAdapter()
+ set_request_adapter(adapter)
+ body = flask.request.get_json()
+ # pylint: disable=protected-access
+ g = dash_app._initialize_context(body, adapter)
+ func = dash_app._prepare_callback(g, body)
+ args = dash_app._inputs_to_vals(g.inputs_list + g.states_list)
+ ctx = copy_context()
+ partial_func = dash_app._execute_callback(func, args, g.outputs_list, g)
+ response_data = ctx.run(partial_func)
+ if asyncio.iscoroutine(response_data):
+ raise Exception(
+ "You are trying to use a coroutine without dash[async]. "
+ "Please install the dependencies via `pip install dash[async]` and ensure "
+ "that `use_async=False` is not being passed to the app."
+ )
+ g.dash_response.set_data(response_data)
+ return g.dash_response
+
+ async def _dispatch_async():
+ adapter = FlaskRequestAdapter()
+ set_request_adapter(adapter)
+ body = flask.request.get_json()
+ # pylint: disable=protected-access
+ g = dash_app._initialize_context(body, adapter)
+ func = dash_app._prepare_callback(g, body)
+ args = dash_app._inputs_to_vals(g.inputs_list + g.states_list)
+ ctx = copy_context()
+ partial_func = dash_app._execute_callback(func, args, g.outputs_list, g)
+ response_data = ctx.run(partial_func)
+ if asyncio.iscoroutine(response_data):
+ response_data = await response_data
+ g.dash_response.set_data(response_data)
+ return g.dash_response
+
+ if use_async:
+ return _dispatch_async
+ return _dispatch
+
+ def _serve_default_favicon(self):
+
+ return flask.Response(
+ pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon"
+ )
+
+ def register_timing_hooks(self, app, _first_run):
+ def _before_request():
+ flask.g.timing_information = {
+ "__dash_server": {"dur": time.time(), "desc": None}
+ }
+
+ def _after_request(response):
+ timing_information = flask.g.get("timing_information", None)
+ if timing_information is None:
+ return response
+ dash_total = timing_information.get("__dash_server", None)
+ if dash_total is not None:
+ dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000)
+ for name, info in timing_information.items():
+ value = name
+ if info.get("desc") is not None:
+ value += f';desc="{info["desc"]}"'
+ if info.get("dur") is not None:
+ value += f";dur={info['dur']}"
+ response.headers.add("Server-Timing", value)
+ return response
+
+ self.before_request(app, _before_request)
+ self.after_request(app, _after_request)
+
+ def register_callback_api_routes(self, app, callback_api_paths):
+ """
+ Register callback API endpoints on the Flask app.
+ Each key in callback_api_paths is a route, each value is a handler (sync or async).
+ The view function parses the JSON body and passes it to the handler.
+ """
+ for path, handler in callback_api_paths.items():
+ endpoint = f"dash_callback_api_{path}"
+ route = path if path.startswith("/") else f"/{path}"
+ methods = ["POST"]
+
+ if inspect.iscoroutinefunction(handler):
+
+ async def view_func(*args, handler=handler, **kwargs):
+ data = flask.request.get_json()
+ result = await handler(**data) if data else await handler()
+ return flask.jsonify(result)
+
+ else:
+
+ def view_func(*args, handler=handler, **kwargs):
+ data = flask.request.get_json()
+ result = handler(**data) if data else handler()
+ return flask.jsonify(result)
+
+ # Flask 2.x+ supports async views natively
+ app.add_url_rule(
+ route, endpoint=endpoint, view_func=view_func, methods=methods
+ )
+
+
+class FlaskRequestAdapter:
+ @staticmethod
+ def get_args():
+ return flask.request.args
+
+ @staticmethod
+ def get_root():
+ return flask.request.url_root
+
+ @staticmethod
+ def get_json():
+ return flask.request.get_json()
+
+ @staticmethod
+ def is_json():
+ return flask.request.is_json
+
+ @staticmethod
+ def get_cookies():
+ return flask.request.cookies
+
+ @staticmethod
+ def get_headers():
+ return flask.request.headers
+
+ @staticmethod
+ def get_url():
+ return flask.request.url
+
+ @staticmethod
+ def get_full_path():
+ return flask.request.full_path
+
+ @staticmethod
+ def get_remote_addr():
+ return flask.request.remote_addr
+
+ @staticmethod
+ def get_origin():
+ return getattr(flask.request, "origin", None)
+
+ @staticmethod
+ def get_path():
+ return flask.request.path
diff --git a/dash/backend/quart.py b/dash/backend/quart.py
new file mode 100644
index 0000000000..830d7dd3b9
--- /dev/null
+++ b/dash/backend/quart.py
@@ -0,0 +1,414 @@
+import inspect
+import pkgutil
+import mimetypes
+import sys
+import time
+from contextvars import copy_context
+import traceback
+import re
+
+try:
+ import quart
+ from quart import Quart, Response, jsonify, request, Blueprint
+except ImportError:
+ quart = None
+ Quart = None
+ Response = None
+ jsonify = None
+ request = None
+ Blueprint = None
+from dash.exceptions import PreventUpdate, InvalidResourceError
+from dash.backend import set_request_adapter
+from dash.fingerprint import check_fingerprint
+from dash import _validate
+from .base_server import BaseDashServer
+
+
+class QuartDashServer(BaseDashServer):
+ """Quart implementation of the Dash server factory.
+
+ All Quart/async specific imports are at the top-level (per user request) so
+ Quart must be installed when this module is imported.
+ """
+
+ def __init__(self) -> None:
+ self.config = {}
+ self.error_handling_mode = "prune"
+ super().__init__()
+
+ def __call__(self, server, *args, **kwargs):
+ return server(*args, **kwargs)
+
+ def create_app(self, name="__main__", config=None):
+ app = Quart(name)
+ if config:
+ for key, value in config.items():
+ app.config[key] = value
+ return app
+
+ def register_assets_blueprint(
+ self, app, blueprint_name, assets_url_path, assets_folder
+ ):
+ bp = Blueprint(
+ blueprint_name,
+ __name__,
+ static_folder=assets_folder,
+ static_url_path=assets_url_path,
+ )
+ app.register_blueprint(bp)
+
+ def _get_traceback(self, _secret, error: Exception):
+ tb = error.__traceback__
+ errors = traceback.format_exception(type(error), error, tb)
+ pass_errs = []
+ callback_handled = False
+ for err in errors:
+ if self.error_handling_mode == "prune":
+ if not callback_handled:
+ if "callback invoked" in str(err) and "_callback.py" in str(err):
+ callback_handled = True
+ continue
+ pass_errs.append(err)
+ formatted_tb = "".join(pass_errs)
+ error_type = type(error).__name__
+ error_msg = str(error)
+
+ # Parse traceback lines to group by file
+ file_cards = []
+ pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)')
+ lines = formatted_tb.split("\n")
+ current_file = None
+ card_lines = []
+
+ for line in lines[:-1]: # Skip the last line (error message)
+ match = pattern.match(line)
+ if match:
+ if current_file and card_lines:
+ file_cards.append((current_file, card_lines))
+ current_file = (
+ f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})"
+ )
+ card_lines = [line]
+ elif current_file:
+ card_lines.append(line)
+ if current_file and card_lines:
+ file_cards.append((current_file, card_lines))
+
+ cards_html = ""
+ for filename, card in file_cards:
+ cards_html += (
+ f"""
+
+
+
"""
+ + "\n".join(card)
+ + """
+
+ """
+ )
+
+ html = f"""
+
+
+
+ {error_type}: {error_msg} // Quart Debugger
+
+
+
+
+
{error_type}
+
+
{error_type}: {error_msg}
+
+
Traceback (most recent call last)
+ {cards_html}
+
{error_type}: {error_msg}
+
+
This is the Copy/Paste friendly version of the traceback.
+
+
+
+ The debugger caught an exception in your ASGI application. You can now
+ look at the traceback which led to the error.
+
+
+
+
+
+ """
+ return html
+
+ def register_prune_error_handler(self, app, secret, prune_errors):
+ if prune_errors:
+ self.error_handling_mode = "prune"
+ else:
+ self.error_handling_mode = "raise"
+
+ @app.errorhandler(Exception)
+ async def _wrap_errors(error):
+ tb = self._get_traceback(secret, error)
+ return Response(tb, status=500, content_type="text/html")
+
+ def register_timing_hooks(self, app, _first_run): # parity with Flask factory
+ @app.before_request
+ async def _before_request(): # pragma: no cover - timing infra
+ quart.g.timing_information = {
+ "__dash_server": {"dur": time.time(), "desc": None}
+ }
+
+ @app.after_request
+ async def _after_request(response): # pragma: no cover - timing infra
+ timing_information = getattr(quart.g, "timing_information", None)
+ if timing_information is None:
+ return response
+ dash_total = timing_information.get("__dash_server", None)
+ if dash_total is not None:
+ dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000)
+ for name, info in timing_information.items():
+ value = name
+ if info.get("desc") is not None:
+ value += f';desc="{info["desc"]}"'
+ if info.get("dur") is not None:
+ value += f";dur={info['dur']}"
+ # Quart/Werkzeug headers expose 'add' (not 'append')
+ if hasattr(response.headers, "add"):
+ response.headers.add("Server-Timing", value)
+ else: # fallback just in case
+ response.headers["Server-Timing"] = value
+ return response
+
+ def register_error_handlers(self, app):
+ @app.errorhandler(PreventUpdate)
+ async def _prevent_update(_):
+ return "", 204
+
+ @app.errorhandler(InvalidResourceError)
+ async def _invalid_resource(err):
+ return err.args[0], 404
+
+ def _html_response_wrapper(self, view_func):
+ async def wrapped(*_args, **_kwargs):
+ html_val = view_func() if callable(view_func) else view_func
+ if inspect.iscoroutine(html_val): # handle async function returning html
+ html_val = await html_val
+ html = str(html_val)
+ return Response(html, content_type="text/html")
+
+ return wrapped
+
+ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None):
+ app.add_url_rule(
+ rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"]
+ )
+
+ def setup_index(self, dash_app):
+ async def index(*args, **kwargs):
+ adapter = QuartRequestAdapter()
+ set_request_adapter(adapter)
+ adapter.set_request()
+ return Response(dash_app.index(*args, **kwargs), content_type="text/html")
+
+ # pylint: disable=protected-access
+ dash_app._add_url("", index, methods=["GET"])
+
+ def setup_catchall(self, dash_app):
+ async def catchall(
+ path, *args, **kwargs
+ ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument
+ adapter = QuartRequestAdapter()
+ set_request_adapter(adapter)
+ adapter.set_request()
+ return Response(dash_app.index(*args, **kwargs), content_type="text/html")
+
+ # pylint: disable=protected-access
+ dash_app._add_url("", catchall, methods=["GET"])
+
+ def before_request(self, app, func):
+ app.before_request(func)
+
+ def after_request(self, app, func):
+ @app.after_request
+ async def _after(response):
+ if func is not None:
+ result = func()
+ if inspect.iscoroutine(result): # Allow async hooks
+ await result
+ return response
+
+ def run(self, _dash_app, app, host, port, debug, **kwargs):
+ self.config = {"debug": debug, **kwargs} if debug else kwargs
+ app.run(host=host, port=port, debug=debug, **kwargs)
+
+ def make_response(self, data, mimetype=None, content_type=None):
+ return Response(data, mimetype=mimetype, content_type=content_type)
+
+ def jsonify(self, obj):
+ return jsonify(obj)
+
+ def get_request_adapter(self):
+ return QuartRequestAdapter
+
+ def serve_component_suites(
+ self, dash_app, package_name, fingerprinted_path
+ ): # noqa: ARG002 unused req preserved for interface parity
+ path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path)
+ _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg)
+ extension = "." + path_in_pkg.split(".")[-1]
+ mimetype = mimetypes.types_map.get(extension, "application/octet-stream")
+ package = sys.modules[package_name]
+ dash_app.logger.debug(
+ "serving -- package: %s[%s] resource: %s => location: %s",
+ package_name,
+ getattr(package, "__version__", "unknown"),
+ path_in_pkg,
+ package.__path__,
+ )
+ data = pkgutil.get_data(package_name, path_in_pkg)
+ headers = {}
+ if has_fingerprint:
+ headers["Cache-Control"] = "public, max-age=31536000"
+
+ return Response(data, content_type=mimetype, headers=headers)
+
+ def setup_component_suites(self, dash_app):
+ async def serve(package_name, fingerprinted_path):
+ return self.serve_component_suites(
+ dash_app, package_name, fingerprinted_path
+ )
+
+ # pylint: disable=protected-access
+ dash_app._add_url(
+ "_dash-component-suites//",
+ serve,
+ )
+
+ # pylint: disable=unused-argument
+ def dispatch(self, app, dash_app, use_async=True): # Quart always async
+ async def _dispatch():
+ adapter = QuartRequestAdapter()
+ set_request_adapter(adapter)
+ adapter.set_request()
+ body = await request.get_json()
+ # pylint: disable=protected-access
+ g = dash_app._initialize_context(body, adapter)
+ # pylint: disable=protected-access
+ func = dash_app._prepare_callback(g, body)
+ # pylint: disable=protected-access
+ args = dash_app._inputs_to_vals(g.inputs_list + g.states_list)
+ ctx = copy_context()
+ # pylint: disable=protected-access
+ partial_func = dash_app._execute_callback(func, args, g.outputs_list, g)
+ response_data = ctx.run(partial_func)
+ if inspect.iscoroutine(response_data): # if user callback is async
+ response_data = await response_data
+ return Response(response_data, content_type="application/json")
+
+ return _dispatch
+
+ def register_callback_api_routes(self, app, callback_api_paths):
+ """
+ Register callback API endpoints on the Quart app.
+ Each key in callback_api_paths is a route, each value is a handler (sync or async).
+ The view function parses the JSON body and passes it to the handler.
+ """
+ for path, handler in callback_api_paths.items():
+ endpoint = f"dash_callback_api_{path}"
+ route = path if path.startswith("/") else f"/{path}"
+ methods = ["POST"]
+
+ def _make_view_func(handler):
+ if inspect.iscoroutinefunction(handler):
+
+ async def async_view_func(*args, **kwargs):
+ data = await request.get_json()
+ result = await handler(**data) if data else await handler()
+ return jsonify(result)
+
+ return async_view_func
+
+ async def sync_view_func(*args, **kwargs):
+ data = await request.get_json()
+ result = handler(**data) if data else handler()
+ return jsonify(result)
+
+ return sync_view_func
+
+ view_func = _make_view_func(handler)
+ app.add_url_rule(
+ route, endpoint=endpoint, view_func=view_func, methods=methods
+ )
+
+ def _serve_default_favicon(self):
+ return Response(
+ pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon"
+ )
+
+
+class QuartRequestAdapter:
+ def __init__(self) -> None:
+ self._request = None
+
+ def set_request(self) -> None:
+ self._request = request
+
+ # Accessors (instance-based)
+ def get_root(self):
+ return self._request.root_url
+
+ def get_args(self):
+ return self._request.args
+
+ async def get_json(self):
+ return await self._request.get_json()
+
+ def is_json(self):
+ return self._request.is_json
+
+ def get_cookies(self):
+ return self._request.cookies
+
+ def get_headers(self):
+ return self._request.headers
+
+ def get_full_path(self):
+ return self._request.full_path
+
+ def get_url(self):
+ return str(self._request.url)
+
+ def get_remote_addr(self):
+ return self._request.remote_addr
+
+ def get_origin(self):
+ return self._request.headers.get("origin")
+
+ def get_path(self):
+ return self._request.path
diff --git a/dash/backend/registry.py b/dash/backend/registry.py
new file mode 100644
index 0000000000..4aae9fafc5
--- /dev/null
+++ b/dash/backend/registry.py
@@ -0,0 +1,29 @@
+import importlib
+
+_backend_imports = {
+ "flask": ("dash.backend.flask", "FlaskDashServer"),
+ "fastapi": ("dash.backend.fastapi", "FastAPIDashServer"),
+ "quart": ("dash.backend.quart", "QuartDashServer"),
+}
+
+
+def register_backend(name, module_path, class_name):
+ """Register a new backend by name."""
+ _backend_imports[name.lower()] = (module_path, class_name)
+
+
+def get_backend(name):
+ try:
+ module_name, class_name = _backend_imports[name.lower()]
+ module = importlib.import_module(module_name)
+ return getattr(module, class_name)
+ except KeyError as e:
+ raise ValueError(f"Unknown backend: {name}") from e
+ except ImportError as e:
+ raise ImportError(
+ f"Could not import module '{module_name}' for backend '{name}': {e}"
+ ) from e
+ except AttributeError as e:
+ raise AttributeError(
+ f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}"
+ ) from e
diff --git a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js
index 176cb2c6f8..db4c6ddd2b 100644
--- a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js
+++ b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js
@@ -121,13 +121,18 @@ function BackendError({error, base}) {
const MAX_MESSAGE_LENGTH = 40;
/* eslint-disable no-inline-comments */
function UnconnectedErrorContent({error, base}) {
+ // Helper to detect full HTML document
+ const isFullHtmlDoc =
+ typeof error.html === 'string' &&
+ error.html.trim().toLowerCase().startsWith('
- {/*
- * 40 is a rough heuristic - if longer than 40 then the
- * message might overflow into ellipses in the title above &
- * will need to be displayed in full in this error body
- */}
+ {/* Frontend error message */}
{typeof error.message !== 'string' ||
error.message.length < MAX_MESSAGE_LENGTH ? null : (
@@ -137,6 +142,7 @@ function UnconnectedErrorContent({error, base}) {
)}
+ {/* Frontend stack trace */}
{typeof error.stack !== 'string' ? null : (
@@ -149,7 +155,6 @@ function UnconnectedErrorContent({error, base}) {
browser's console.)
-
{error.stack.split('\n').map((line, i) => (
{line}
))}
@@ -157,24 +162,30 @@ function UnconnectedErrorContent({error, base}) {
)}
- {/* Backend Error */}
- {typeof error.html !== 'string' ? null : error.html
- .substring(0, '
- {/* Embed werkzeug debugger in an iframe to prevent
- CSS leaking - werkzeug HTML includes a bunch
- of CSS on base html elements like `
`
- */}
- ) : (
+ ) : isHtmlFragment ? (
+ // Backend error: HTML fragment
+
+ ) : typeof error.html === 'string' ? (
+ // Backend error: plain text
- )}
+ ) : null}
);
}
diff --git a/dash/dash.py b/dash/dash.py
index 8430259c27..6bba3aadfd 100644
--- a/dash/dash.py
+++ b/dash/dash.py
@@ -4,7 +4,6 @@
import collections
import importlib
import warnings
-from contextvars import copy_context
from importlib.machinery import ModuleSpec
from importlib.util import find_spec
from importlib import metadata
@@ -12,11 +11,9 @@
import threading
import re
import logging
-import time
import mimetypes
import hashlib
import base64
-import traceback
from urllib.parse import urlparse
from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List
@@ -29,7 +26,7 @@
from dash import html
from dash import dash_table
-from .fingerprint import build_fingerprint, check_fingerprint
+from .fingerprint import build_fingerprint
from .resources import Scripts, Css
from .dependencies import (
Input,
@@ -38,11 +35,10 @@
)
from .development.base_component import ComponentRegistry
from .exceptions import (
- PreventUpdate,
- InvalidResourceError,
ProxyError,
DuplicateCallback,
)
+from .backend import get_request_adapter, get_backend
from .version import __version__
from ._configs import get_combined_config, pathname_configs, pages_folder_config
from ._utils import (
@@ -68,9 +64,10 @@
from . import _watch
from . import _get_app
-from ._get_app import with_app_context, with_app_context_async, with_app_context_factory
+from ._get_app import with_app_context, with_app_context_factory
from ._grouping import map_grouping, grouping_len, update_args_group
from ._obsolete import ObsoleteChecker
+from ._callback_context import callback_context
from . import _pages
from ._pages import (
@@ -157,61 +154,34 @@
page_container = None
-def _get_traceback(secret, error: Exception):
+def _is_flask_instance(obj):
try:
# pylint: disable=import-outside-toplevel
- from werkzeug.debug import tbtools
+ from flask import Flask
+
+ return isinstance(obj, Flask)
except ImportError:
- tbtools = None
+ return False
- def _get_skip(error):
- from dash._callback import ( # pylint: disable=import-outside-toplevel
- _invoke_callback,
- _async_invoke_callback,
- )
- tb = error.__traceback__
- skip = 1
- while tb.tb_next is not None:
- skip += 1
- tb = tb.tb_next
- if tb.tb_frame.f_code in [
- _invoke_callback.__code__,
- _async_invoke_callback.__code__,
- ]:
- return skip
-
- return skip
-
- def _do_skip(error):
- from dash._callback import ( # pylint: disable=import-outside-toplevel
- _invoke_callback,
- _async_invoke_callback,
- )
+def _is_fastapi_instance(obj):
+ try:
+ # pylint: disable=import-outside-toplevel
+ from fastapi import FastAPI
- tb = error.__traceback__
- while tb.tb_next is not None:
- if tb.tb_frame.f_code in [
- _invoke_callback.__code__,
- _async_invoke_callback.__code__,
- ]:
- return tb.tb_next
- tb = tb.tb_next
- return error.__traceback__
+ return isinstance(obj, FastAPI)
+ except ImportError:
+ return False
- # werkzeug<2.1.0
- if hasattr(tbtools, "get_current_traceback"):
- return tbtools.get_current_traceback( # type: ignore
- skip=_get_skip(error)
- ).render_full()
- if hasattr(tbtools, "DebugTraceback"):
- # pylint: disable=no-member
- return tbtools.DebugTraceback( # type: ignore
- error, skip=_get_skip(error)
- ).render_debugger_html(True, secret, True)
+def _is_quart_instance(obj):
+ try:
+ # pylint: disable=import-outside-toplevel
+ from quart import Quart
- return "".join(traceback.format_exception(type(error), error, _do_skip(error)))
+ return isinstance(obj, Quart)
+ except ImportError:
+ return False
# Singleton signal to not update an output, alternative to PreventUpdate
@@ -249,6 +219,12 @@ class Dash(ObsoleteChecker):
``flask.Flask``: use this pre-existing Flask server.
:type server: boolean or flask.Flask
+ :param backend: The backend to use for the Dash app. Can be a string
+ (name of the backend) or a backend class. Default is None, which
+ selects the Flask backend. Currently, "flask", "fastapi", and "quart" backends
+ are supported.
+ :type backend: string or type
+
:param assets_folder: a path, relative to the current working directory,
for extra files to be used in the browser. Default ``'assets'``.
All .js and .css files will be loaded immediately unless excluded by
@@ -421,16 +397,17 @@ class Dash(ObsoleteChecker):
_plotlyjs_url: str
STARTUP_ROUTES: list = []
- server: flask.Flask
+ server: Any
# Layout is a complex type which can be many things
_layout: Any
_extra_components: Any
- def __init__( # pylint: disable=too-many-statements
+ def __init__( # pylint: disable=too-many-statements, too-many-branches
self,
name: Optional[str] = None,
- server: Union[bool, flask.Flask] = True,
+ server: Union[bool, Callable[[], Any]] = True,
+ backend: Union[str, type, None] = None,
assets_folder: str = "assets",
pages_folder: str = "pages",
use_pages: Optional[bool] = None,
@@ -488,16 +465,55 @@ def __init__( # pylint: disable=too-many-statements
caller_name: str = name if name is not None else get_caller_name()
- # We have 3 cases: server is either True (we create the server), False
- # (defer server creation) or a Flask app instance (we use their server)
- if isinstance(server, flask.Flask):
- self.server = server
+ # Determine backend
+ if backend is None:
+ backend_cls = get_backend("flask")
+ elif isinstance(backend, str):
+ backend_cls = get_backend(backend)
+ elif isinstance(backend, type):
+ backend_cls = backend
+ else:
+ raise ValueError("Invalid backend argument")
+
+ # Determine server and backend instance
+ if server not in (None, True, False):
+ # User provided a server instance (e.g., Flask, Quart, FastAPI)
+ if _is_flask_instance(server):
+ inferred_backend = "flask"
+ elif _is_quart_instance(server):
+ inferred_backend = "quart"
+ elif _is_fastapi_instance(server):
+ inferred_backend = "fastapi"
+ else:
+ raise ValueError("Unsupported server type")
+ # Validate that backend matches server type if both are provided
+ if backend is not None:
+ if isinstance(backend, type):
+ # get_backend returns the backend class for a string
+ # So we compare the class names
+ expected_backend_cls = get_backend(inferred_backend)
+ if (
+ backend.__module__ != expected_backend_cls.__module__
+ or backend.__name__ != expected_backend_cls.__name__
+ ):
+ raise ValueError(
+ f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'."
+ )
+ elif not isinstance(backend, str):
+ raise ValueError("Invalid backend argument")
+ elif backend.lower() != inferred_backend:
+ raise ValueError(
+ f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'."
+ )
+ backend_cls = get_backend(inferred_backend)
if name is None:
caller_name = getattr(server, "name", caller_name)
- elif isinstance(server, bool):
- self.server = flask.Flask(caller_name) if server else None # type: ignore
+ self.backend = backend_cls()
+ self.server = server
else:
- raise ValueError("server must be a Flask app or a boolean")
+ # No server instance provided, create backend and let backend create server
+ self.backend = backend_cls()
+ self.server = self.backend.create_app(caller_name) # type: ignore
base_prefix, routes_prefix, requests_prefix = pathname_configs(
url_base_pathname, routes_pathname_prefix, requests_pathname_prefix
@@ -671,11 +687,15 @@ def _setup_hooks(self):
if self._hooks.get_hooks("error"):
self._on_error = self._hooks.HookErrorHandler(self._on_error)
- def init_app(self, app: Optional[flask.Flask] = None, **kwargs) -> None:
- """Initialize the parts of Dash that require a flask app."""
-
+ def init_app(self, app: Optional[Any] = None, **kwargs) -> None:
config = self.config
-
+ config.unset_read_only(
+ [
+ "url_base_pathname",
+ "routes_pathname_prefix",
+ "requests_pathname_prefix",
+ ]
+ )
config.update(kwargs)
config.set_read_only(
[
@@ -685,89 +705,67 @@ def init_app(self, app: Optional[flask.Flask] = None, **kwargs) -> None:
],
"Read-only: can only be set in the Dash constructor or during init_app()",
)
-
if app is not None:
self.server = app
-
bp_prefix = config.routes_pathname_prefix.replace("/", "_").replace(".", "_")
assets_blueprint_name = f"{bp_prefix}dash_assets"
-
- self.server.register_blueprint(
- flask.Blueprint(
- assets_blueprint_name,
- config.name,
- static_folder=self.config.assets_folder,
- static_url_path=config.routes_pathname_prefix
- + self.config.assets_url_path.lstrip("/"),
- )
+ self.backend.register_assets_blueprint(
+ self.server,
+ assets_blueprint_name,
+ config.routes_pathname_prefix + self.config.assets_url_path.lstrip("/"),
+ self.config.assets_folder,
)
-
if config.compress:
try:
- # pylint: disable=import-outside-toplevel
- from flask_compress import Compress # type: ignore[reportMissingImports]
+ import flask_compress # pylint: disable=import-outside-toplevel
- # gzip
+ Compress = flask_compress.Compress
Compress(self.server)
-
_flask_compress_version = parse_version(
_get_distribution_version("flask_compress")
)
-
if not hasattr(
self.server.config, "COMPRESS_ALGORITHM"
) and _flask_compress_version >= parse_version("1.6.0"):
- # flask-compress==1.6.0 changed default to ['br', 'gzip']
- # and non-overridable default compression with Brotli is
- # causing performance issues
self.server.config["COMPRESS_ALGORITHM"] = ["gzip"]
except ImportError as error:
raise ImportError(
"To use the compress option, you need to install dash[compress]"
) from error
-
- @self.server.errorhandler(PreventUpdate)
- def _handle_error(_):
- """Handle a halted callback and return an empty 204 response."""
- return "", 204
-
- self.server.before_request(self._setup_server)
-
- # add a handler for components suites errors to return 404
- self.server.errorhandler(InvalidResourceError)(self._invalid_resources_handler)
-
+ self.backend.register_error_handlers(self.server)
+ self.backend.before_request(self.server, self._setup_server)
self._setup_routes()
-
_get_app.APP = self
self.enable_pages()
-
self._setup_plotlyjs()
def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> None:
full_name = self.config.routes_pathname_prefix + name
-
- self.server.add_url_rule(
- full_name, view_func=view_func, endpoint=full_name, methods=list(methods)
+ self.backend.add_url_rule(
+ self.server,
+ full_name,
+ view_func=view_func,
+ endpoint=full_name,
+ methods=list(methods),
)
-
- # record the url in Dash.routes so that it can be accessed later
- # e.g. for adding authentication with flask_login
self.routes.append(full_name)
def _setup_routes(self):
- self._add_url(
- "_dash-component-suites//",
- self.serve_component_suites,
- )
+ self.backend.setup_component_suites(self)
self._add_url("_dash-layout", self.serve_layout)
self._add_url("_dash-dependencies", self.dependencies)
- if self._use_async:
- self._add_url("_dash-update-component", self.async_dispatch, ["POST"])
- else:
- self._add_url("_dash-update-component", self.dispatch, ["POST"])
+ self._add_url(
+ "_dash-update-component",
+ self.backend.dispatch(self.server, self, self._use_async),
+ ["POST"],
+ )
self._add_url("_reload-hash", self.serve_reload_hash)
- self._add_url("_favicon.ico", self._serve_default_favicon)
- self._add_url("", self.index)
+ self._add_url(
+ "_favicon.ico",
+ self.backend._serve_default_favicon, # pylint: disable=protected-access
+ )
+ self.backend.setup_index(self)
+ self.backend.setup_catchall(self)
if jupyter_dash.active:
self._add_url(
@@ -781,9 +779,6 @@ def _setup_routes(self):
hook.data["methods"],
)
- # catch-all for front-end routes, used by dcc.Location
- self._add_url("", self.index)
-
def setup_apis(self):
"""
Register API endpoints for all callbacks defined using `dash.callback`.
@@ -807,30 +802,8 @@ def setup_apis(self):
)
self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k)
- def make_parse_body(func):
- def _parse_body():
- if flask.request.is_json:
- data = flask.request.get_json()
- return flask.jsonify(func(**data))
- return flask.jsonify({})
-
- return _parse_body
-
- def make_parse_body_async(func):
- async def _parse_body_async():
- if flask.request.is_json:
- data = flask.request.get_json()
- result = await func(**data)
- return flask.jsonify(result)
- return flask.jsonify({})
-
- return _parse_body_async
-
- for path, func in self.callback_api_paths.items():
- if asyncio.iscoroutinefunction(func):
- self._add_url(path, make_parse_body_async(func), ["POST"])
- else:
- self._add_url(path, make_parse_body(func), ["POST"])
+ # Delegate to the server factory for route registration
+ self.backend.register_callback_api_routes(self.server, self.callback_api_paths)
def _setup_plotlyjs(self):
# pylint: disable=import-outside-toplevel
@@ -902,7 +875,7 @@ def serve_layout(self):
layout = hook(layout)
# TODO - Set browser cache limit - pass hash into frontend
- return flask.Response(
+ return self.backend.make_response(
to_json(layout),
mimetype="application/json",
)
@@ -966,7 +939,7 @@ def serve_reload_hash(self):
_reload.hard = False
_reload.changed_assets = []
- return flask.jsonify(
+ return self.backend.jsonify(
{
"reloadHash": _hash,
"hard": hard,
@@ -1159,58 +1132,21 @@ def _generate_meta(self):
return meta_tags + self.config.meta_tags
- # Serve the JS bundles for each package
- def serve_component_suites(self, package_name, fingerprinted_path):
- path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path)
-
- _validate.validate_js_path(self.registered_paths, package_name, path_in_pkg)
-
- extension = "." + path_in_pkg.split(".")[-1]
- mimetype = mimetypes.types_map.get(extension, "application/octet-stream")
-
- package = sys.modules[package_name]
- self.logger.debug(
- "serving -- package: %s[%s] resource: %s => location: %s",
- package_name,
- package.__version__,
- path_in_pkg,
- package.__path__,
- )
-
- response = flask.Response(
- pkgutil.get_data(package_name, path_in_pkg), mimetype=mimetype
- )
-
- if has_fingerprint:
- # Fingerprinted resources are good forever (1 year)
- # No need for ETag as the fingerprint changes with each build
- response.cache_control.max_age = 31536000 # 1 year
- else:
- # Non-fingerprinted resources are given an ETag that
- # will be used / check on future requests
- response.add_etag()
- tag = response.get_etag()[0]
-
- request_etag = flask.request.headers.get("If-None-Match")
-
- if f'"{tag}"' == request_etag:
- response = flask.Response(None, status=304)
-
- return response
-
- @with_app_context
- def index(self, *args, **kwargs): # pylint: disable=unused-argument
+ def index(self, *_args, **_kwargs):
scripts = self._generate_scripts_html()
css = self._generate_css_dist_html()
config = self._generate_config_html()
metas = self._generate_meta()
renderer = self._generate_renderer()
-
- # use self.title instead of app.config.title for backwards compatibility
title = self.title
+ try:
+ request = get_request_adapter()
+ except LookupError:
+ # no request context
+ request = None
- if self.use_pages and self.config.include_pages_meta:
- metas = _page_meta_tags(self) + metas
+ if self.use_pages and self.config.include_pages_meta and request:
+ metas = _page_meta_tags(self, request) + metas
if self._favicon:
favicon_mod_time = os.path.getmtime(
@@ -1314,7 +1250,7 @@ def interpolate_index(self, **kwargs):
@with_app_context
def dependencies(self):
- return flask.Response(
+ return self.backend.make_response(
to_json(self._callback_list),
content_type="application/json",
)
@@ -1417,8 +1353,11 @@ def callback(self, *_args, **_kwargs) -> Callable[..., Any]:
**_kwargs,
)
+ def _inputs_to_vals(self, inputs):
+ return inputs_to_vals(inputs)
+
# pylint: disable=R0915
- def _initialize_context(self, body):
+ def _initialize_context(self, body, adapter):
"""Initialize the global context for the request."""
g = AttributeDict({})
g.inputs_list = body.get("inputs", [])
@@ -1430,12 +1369,15 @@ def _initialize_context(self, body):
{"prop_id": x, "value": g.input_values.get(x)}
for x in body.get("changedPropIds", [])
]
- g.dash_response = flask.Response(mimetype="application/json")
- g.cookies = dict(**flask.request.cookies)
- g.headers = dict(**flask.request.headers)
- g.path = flask.request.full_path
- g.remote = flask.request.remote_addr
- g.origin = flask.request.origin
+ g.dash_response = self.backend.make_response(
+ mimetype="application/json", data=None
+ )
+ g.cookies = dict(adapter.get_cookies())
+ g.headers = dict(adapter.get_headers())
+ g.args = adapter.get_args()
+ g.path = adapter.get_full_path()
+ g.remote = adapter.get_remote_addr()
+ g.origin = adapter.get_origin()
g.updated_props = {}
return g
@@ -1499,11 +1441,6 @@ def _prepare_grouping(self, data_list, indices):
def _execute_callback(self, func, args, outputs_list, g):
"""Execute the callback with the prepared arguments."""
- g.cookies = dict(**flask.request.cookies)
- g.headers = dict(**flask.request.headers)
- g.path = flask.request.full_path
- g.remote = flask.request.remote_addr
- g.origin = flask.request.origin
g.custom_data = AttributeDict({})
for hook in self._hooks.get_hooks("custom_data"):
@@ -1522,47 +1459,6 @@ def _execute_callback(self, func, args, outputs_list, g):
)
return partial_func
- @with_app_context_async
- async def async_dispatch(self):
- body = flask.request.get_json()
- g = self._initialize_context(body)
- func = self._prepare_callback(g, body)
- args = inputs_to_vals(g.inputs_list + g.states_list)
-
- ctx = copy_context()
- partial_func = self._execute_callback(func, args, g.outputs_list, g)
- if asyncio.iscoroutine(func):
- response_data = await ctx.run(partial_func)
- else:
- response_data = ctx.run(partial_func)
-
- if asyncio.iscoroutine(response_data):
- response_data = await response_data
-
- g.dash_response.set_data(response_data)
- return g.dash_response
-
- @with_app_context
- def dispatch(self):
- body = flask.request.get_json()
- g = self._initialize_context(body)
- func = self._prepare_callback(g, body)
- args = inputs_to_vals(g.inputs_list + g.states_list)
-
- ctx = copy_context()
- partial_func = self._execute_callback(func, args, g.outputs_list, g)
- response_data = ctx.run(partial_func)
-
- if asyncio.iscoroutine(response_data):
- raise Exception(
- "You are trying to use a coroutine without dash[async]. "
- "Please install the dependencies via `pip install dash[async]` and ensure "
- "that `use_async=False` is not being passed to the app."
- )
-
- g.dash_response.set_data(response_data)
- return g.dash_response
-
def _setup_server(self):
if self._got_first_request["setup_server"]:
return
@@ -1626,7 +1522,7 @@ def _setup_server(self):
manager=manager,
)
def cancel_call(*_):
- job_ids = flask.request.args.getlist("cancelJob")
+ job_ids = callback_context.args.getlist("cancelJob")
executor = _callback.context_value.get().background_callback_manager
if job_ids:
for job_id in job_ids:
@@ -1695,12 +1591,6 @@ def _walk_assets_directory(self):
def _invalid_resources_handler(err):
return err.args[0], 404
- @staticmethod
- def _serve_default_favicon():
- return flask.Response(
- pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon"
- )
-
def csp_hashes(self, hash_algorithm="sha256") -> Sequence[str]:
"""Calculates CSP hashes (sha + base64) of all inline scripts, such that
one of the biggest benefits of CSP (disallowing general inline scripts)
@@ -1947,6 +1837,7 @@ def enable_dev_tools(
dev_tools_silence_routes_logging: Optional[bool] = None,
dev_tools_disable_version_check: Optional[bool] = None,
dev_tools_prune_errors: Optional[bool] = None,
+ first_run: bool = True,
) -> bool:
"""Activate the dev tools, called by `run`. If your application
is served by wsgi and you want to activate the dev tools, you can call
@@ -2109,49 +2000,13 @@ def enable_dev_tools(
jupyter_dash.configure_callback_exception_handling(
self, dev_tools.prune_errors
)
- elif dev_tools.prune_errors:
- secret = gen_salt(20)
-
- @self.server.errorhandler(Exception)
- def _wrap_errors(error):
- # find the callback invocation, if the error is from a callback
- # and skip the traceback up to that point
- # if the error didn't come from inside a callback, we won't
- # skip anything.
- tb = _get_traceback(secret, error)
- return tb, 500
+ secret = gen_salt(20)
+ self.backend.register_prune_error_handler(
+ self.server, secret, dev_tools.prune_errors
+ )
if debug and dev_tools.ui:
-
- def _before_request():
- flask.g.timing_information = { # pylint: disable=assigning-non-slot
- "__dash_server": {"dur": time.time(), "desc": None}
- }
-
- def _after_request(response):
- timing_information = flask.g.get("timing_information", None)
- if timing_information is None:
- return response
-
- dash_total = timing_information.get("__dash_server", None)
- if dash_total is not None:
- dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000)
-
- for name, info in timing_information.items():
- value = name
- if info.get("desc") is not None:
- value += f';desc="{info["desc"]}"'
-
- if info.get("dur") is not None:
- value += f";dur={info['dur']}"
-
- response.headers.add("Server-Timing", value)
-
- return response
-
- self.server.before_request(_before_request)
-
- self.server.after_request(_after_request)
+ self.backend.register_timing_hooks(self.server, first_run)
if (
debug
@@ -2435,7 +2290,14 @@ def verify_url_part(served_part, url_part, part_name):
server_url=jupyter_server_url,
)
else:
- self.server.run(host=host, port=port, debug=debug, **flask_run_options)
+ self.backend.run(
+ self,
+ self.server,
+ host=host,
+ port=port,
+ debug=debug,
+ **flask_run_options,
+ )
def enable_pages(self) -> None:
if not self.use_pages:
@@ -2443,7 +2305,6 @@ def enable_pages(self) -> None:
if self.pages_folder:
_import_layouts_from_pages(self.config.pages_folder)
- @self.server.before_request
def router():
if self._got_first_request["pages"]:
return
@@ -2495,7 +2356,7 @@ async def update(pathname_, search_, **states):
)
if callable(title):
title = await execute_async_function(
- title, **(path_variables or {})
+ title, **{**(path_variables or {})}
)
return layout, {"title": title}
@@ -2599,10 +2460,7 @@ def update(pathname_, search_, **states):
Input(_ID_STORE, "data"),
)
- def __call__(self, environ, start_response):
- """
- This method makes instances of Dash WSGI-compliant callables.
- It delegates the actual WSGI handling to the internal Flask app's
- __call__ method.
- """
- return self.server(environ, start_response)
+ self.backend.before_request(self.server, router)
+
+ def __call__(self, *args, **kwargs):
+ return self.backend.__call__(self.server, *args, **kwargs)
diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py
index dc88afe844..2956f1a4c0 100644
--- a/dash/testing/application_runners.py
+++ b/dash/testing/application_runners.py
@@ -171,7 +171,13 @@ def run():
self.port = options["port"]
try:
- app.run(threaded=True, **options)
+ module = app.server.__class__.__module__
+ # FastAPI support
+ if not module.startswith("flask"):
+ app.run(**options)
+ # Dash/Flask fallback
+ else:
+ app.run(threaded=True, **options)
except SystemExit:
logger.info("Server stopped")
except Exception as error:
@@ -229,7 +235,13 @@ def target():
options = kwargs.copy()
try:
- app.run(threaded=True, **options)
+ module = app.server.__class__.__module__
+ # FastAPI support
+ if not module.startswith("flask"):
+ app.run(**options)
+ # Dash/Flask fallback
+ else:
+ app.run(threaded=True, **options)
except SystemExit:
logger.info("Server stopped")
raise
diff --git a/package.json b/package.json
index e78e279c1b..b7416dbb34 100644
--- a/package.json
+++ b/package.json
@@ -44,7 +44,7 @@
"setup-tests.R": "run-s private::test.R.deploy-*",
"citest.integration": "run-s setup-tests.py private::test.integration-*",
"citest.unit": "run-s private::test.unit-**",
- "test": "pytest && cd dash/dash-renderer && npm run test",
+ "test": "pytest --ignore=tests/backend_tests && cd dash/dash-renderer && npm run test",
"first-build": "cd dash/dash-renderer && npm i && cd ../../ && cd components/dash-html-components && npm i && npm run extract && cd ../../ && npm run build"
},
"devDependencies": {
diff --git a/requirements/fastapi.txt b/requirements/fastapi.txt
new file mode 100644
index 0000000000..97dc7cd8c1
--- /dev/null
+++ b/requirements/fastapi.txt
@@ -0,0 +1,2 @@
+fastapi
+uvicorn
diff --git a/requirements/quart.txt b/requirements/quart.txt
new file mode 100644
index 0000000000..60af440c9c
--- /dev/null
+++ b/requirements/quart.txt
@@ -0,0 +1 @@
+quart
diff --git a/setup.py b/setup.py
index 7ed781c20d..950bcbe14d 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,9 @@ def read_req_file(req_type):
"testing": read_req_file("testing"),
"celery": read_req_file("celery"),
"diskcache": read_req_file("diskcache"),
- "compress": read_req_file("compress")
+ "compress": read_req_file("compress"),
+ "fastapi": read_req_file("fastapi"),
+ "quart": read_req_file("quart"),
},
entry_points={
"console_scripts": [
diff --git a/tests/backend_tests/__init__.py b/tests/backend_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py
new file mode 100644
index 0000000000..5fbd28dfd9
--- /dev/null
+++ b/tests/backend_tests/test_preconfig_backends.py
@@ -0,0 +1,217 @@
+import pytest
+from dash import Dash, Input, Output, html, dcc
+
+
+@pytest.mark.parametrize(
+ "backend,fixture,input_value",
+ [
+ ("fastapi", "dash_duo", "Hello FastAPI!"),
+ ("quart", "dash_duo_mp", "Hello Quart!"),
+ ],
+)
+def test_backend_basic_callback(request, backend, fixture, input_value):
+ dash_duo = request.getfixturevalue(fixture)
+ if backend == "fastapi":
+ from fastapi import FastAPI
+
+ server = FastAPI()
+ else:
+ import quart
+
+ server = quart.Quart(__name__)
+ app = Dash(__name__, server=server)
+ app.layout = html.Div(
+ [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")]
+ )
+
+ @app.callback(Output("output", "children"), Input("input", "value"))
+ def update_output(value):
+ return f"You typed: {value}"
+
+ dash_duo.start_server(app)
+ dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}")
+ dash_duo.find_element("#input").clear()
+ dash_duo.find_element("#input").send_keys(f"{backend.title()} Test")
+ dash_duo.wait_for_text_to_equal("#output", f"You typed: {backend.title()} Test")
+ assert dash_duo.get_logs() == []
+
+
+@pytest.mark.parametrize(
+ "backend,fixture,start_server_kwargs",
+ [
+ (
+ "fastapi",
+ "dash_duo",
+ {"debug": True, "reload": False, "dev_tools_ui": True},
+ ),
+ (
+ "quart",
+ "dash_duo_mp",
+ {
+ "debug": True,
+ "use_reloader": False,
+ "dev_tools_hot_reload": False,
+ },
+ ),
+ ],
+)
+def test_backend_error_handling(request, backend, fixture, start_server_kwargs):
+ dash_duo = request.getfixturevalue(fixture)
+ app = Dash(__name__, backend=backend)
+ app.layout = html.Div(
+ [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")]
+ )
+
+ @app.callback(Output("output", "children"), Input("btn", "n_clicks"))
+ def error_callback(n):
+ if n and n > 0:
+ return 1 / 0 # Intentional error
+ return "No error"
+
+ dash_duo.start_server(app, **start_server_kwargs)
+ dash_duo.wait_for_text_to_equal("#output", "No error")
+ dash_duo.find_element("#btn").click()
+ dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1")
+
+
+def get_error_html(dash_duo, index):
+ # error is in an iframe so is annoying to read out - get it from the store
+ return dash_duo.driver.execute_script(
+ "return store.getState().error.backEnd[{}].error.html;".format(index)
+ )
+
+
+@pytest.mark.parametrize(
+ "backend,fixture,start_server_kwargs, error_msg",
+ [
+ (
+ "fastapi",
+ "dash_duo",
+ {
+ "debug": True,
+ "dev_tools_ui": True,
+ "dev_tools_prune_errors": False,
+ "reload": False,
+ },
+ "fastapi.py",
+ ),
+ (
+ "quart",
+ "dash_duo_mp",
+ {
+ "debug": True,
+ "use_reloader": False,
+ "dev_tools_hot_reload": False,
+ "dev_tools_prune_errors": False,
+ },
+ "quart.py",
+ ),
+ ],
+)
+def test_backend_error_handling_no_prune(
+ request, backend, fixture, start_server_kwargs, error_msg
+):
+ dash_duo = request.getfixturevalue(fixture)
+ app = Dash(__name__, backend=backend)
+ app.layout = html.Div(
+ [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")]
+ )
+
+ @app.callback(Output("output", "children"), Input("btn", "n_clicks"))
+ def error_callback(n):
+ if n and n > 0:
+ return 1 / 0 # Intentional error
+ return "No error"
+
+ dash_duo.start_server(app, **start_server_kwargs)
+ dash_duo.wait_for_text_to_equal("#output", "No error")
+ dash_duo.find_element("#btn").click()
+ dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1")
+
+ error0 = get_error_html(dash_duo, 0)
+ assert "in error_callback" in error0
+ assert "ZeroDivisionError" in error0
+ assert "backend" in error0 and error_msg in error0
+
+
+@pytest.mark.parametrize(
+ "backend,fixture,start_server_kwargs, error_msg",
+ [
+ ("fastapi", "dash_duo", {"debug": True, "reload": False}, "fastapi.py"),
+ (
+ "quart",
+ "dash_duo_mp",
+ {
+ "debug": True,
+ "use_reloader": False,
+ "dev_tools_hot_reload": False,
+ },
+ "quart.py",
+ ),
+ ],
+)
+def test_backend_error_handling_prune(
+ request, backend, fixture, start_server_kwargs, error_msg
+):
+ dash_duo = request.getfixturevalue(fixture)
+ app = Dash(__name__, backend=backend)
+ app.layout = html.Div(
+ [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")]
+ )
+
+ @app.callback(Output("output", "children"), Input("btn", "n_clicks"))
+ def error_callback(n):
+ if n and n > 0:
+ return 1 / 0 # Intentional error
+ return "No error"
+
+ dash_duo.start_server(app, **start_server_kwargs)
+ dash_duo.wait_for_text_to_equal("#output", "No error")
+ dash_duo.find_element("#btn").click()
+ dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1")
+
+ error0 = get_error_html(dash_duo, 0)
+ assert "in error_callback" in error0
+ assert "ZeroDivisionError" in error0
+ assert "dash/backend" not in error0 and error_msg not in error0
+
+
+@pytest.mark.parametrize(
+ "backend,fixture,input_value",
+ [
+ ("fastapi", "dash_duo", "Background FastAPI!"),
+ ("quart", "dash_duo_mp", "Background Quart!"),
+ ],
+)
+def test_backend_background_callback(request, backend, fixture, input_value):
+ dash_duo = request.getfixturevalue(fixture)
+ import diskcache
+
+ cache = diskcache.Cache("./cache")
+ from dash.background_callback import DiskcacheManager
+
+ background_callback_manager = DiskcacheManager(cache)
+
+ app = Dash(
+ __name__,
+ backend=backend,
+ background_callback_manager=background_callback_manager,
+ )
+ app.layout = html.Div(
+ [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")]
+ )
+
+ @app.callback(
+ Output("output", "children"), Input("input", "value"), background=True
+ )
+ def update_output_bg(value):
+ return f"Background typed: {value}"
+
+ dash_duo.start_server(app)
+ dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}")
+ dash_duo.find_element("#input").clear()
+ dash_duo.find_element("#input").send_keys(f"{backend.title()} BG Test")
+ dash_duo.wait_for_text_to_equal(
+ "#output", f"Background typed: {backend.title()} BG Test"
+ )
+ assert dash_duo.get_logs() == []
diff --git a/tests/integration/devtools/test_devtools_error_handling.py b/tests/integration/devtools/test_devtools_error_handling.py
index 40d5731202..005bf8c335 100644
--- a/tests/integration/devtools/test_devtools_error_handling.py
+++ b/tests/integration/devtools/test_devtools_error_handling.py
@@ -109,14 +109,14 @@ def test_dveh006_long_python_errors(dash_duo):
assert "in bad_sub" not in error0
# dash and flask part of the traceback ARE included
# since we set dev_tools_prune_errors=False
- assert "dash.py" in error0
+ assert "backend" in error0 and "flask.py" in error0
assert "self.wsgi_app" in error0
error1 = get_error_html(dash_duo, 1)
assert "in update_output" in error1
assert "in bad_sub" in error1
assert "ZeroDivisionError" in error1
- assert "dash.py" in error1
+ assert "backend" in error1 and "flask.py" in error1
assert "self.wsgi_app" in error1
diff --git a/tests/integration/multi_page/test_pages_layout.py b/tests/integration/multi_page/test_pages_layout.py
index 48751021b9..a209ae4517 100644
--- a/tests/integration/multi_page/test_pages_layout.py
+++ b/tests/integration/multi_page/test_pages_layout.py
@@ -3,6 +3,7 @@
from dash import Dash, Input, State, dcc, html, Output
from dash.dash import _ID_LOCATION
from dash.exceptions import NoLayoutException
+from dash.testing.wait import until
def get_app(path1="/", path2="/layout2"):
@@ -57,7 +58,7 @@ def test_pala001_layout(dash_duo, clear_pages_state):
for page in dash.page_registry.values():
dash_duo.find_element("#" + page["id"]).click()
dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"])
- assert dash_duo.driver.title == page["title"], "check that page title updates"
+ until(lambda: dash_duo.driver.title == page["title"], timeout=3)
# test redirects
dash_duo.wait_for_page(url=f"{dash_duo.server_url}/v2")
diff --git a/tests/integration/multi_page/test_pages_relative_path.py b/tests/integration/multi_page/test_pages_relative_path.py
index 6c505ac3f5..24e7209a70 100644
--- a/tests/integration/multi_page/test_pages_relative_path.py
+++ b/tests/integration/multi_page/test_pages_relative_path.py
@@ -2,6 +2,7 @@
import dash
from dash import Dash, dcc, html
+from dash.testing.wait import until
def get_app(app):
@@ -70,7 +71,7 @@ def test_pare002_relative_path_with_url_base_pathname(
for page in dash.page_registry.values():
dash_br.find_element("#" + page["id"]).click()
dash_br.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"])
- assert dash_br.driver.title == page["title"], "check that page title updates"
+ until(lambda: dash_br.driver.title == page["title"], timeout=3)
assert dash_br.get_logs() == [], "browser console should contain no error"
@@ -83,6 +84,6 @@ def test_pare003_absolute_path(dash_duo, clear_pages_state):
for page in dash.page_registry.values():
dash_duo.find_element("#" + page["id"]).click()
dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"])
- assert dash_duo.driver.title == page["title"], "check that page title updates"
+ until(lambda: dash_duo.driver.title == page["title"], timeout=3)
assert dash_duo.get_logs() == [], "browser console should contain no error"