diff --git a/vizro-core/examples/dev/app.py b/vizro-core/examples/dev/app.py index a15e6d73b..d8308ec8f 100644 --- a/vizro-core/examples/dev/app.py +++ b/vizro-core/examples/dev/app.py @@ -995,7 +995,7 @@ def my_custom_table(data_frame=None, chosen_columns: list[str] | None = None): class TooltipNonCrossRangeSlider(vm.RangeSlider): """Custom numeric multi-selector `TooltipNonCrossRangeSlider`.""" - type: Literal["other_range_slider"] = "other_range_slider" + type: Literal["custom_component"] = "custom_component" def build(self): """Extend existing component by calling the super build and update properties.""" @@ -1005,14 +1005,14 @@ def build(self): return range_slider_build_obj -vm.Filter.add_type("selector", TooltipNonCrossRangeSlider) +# vm.Filter.add_type("selector", TooltipNonCrossRangeSlider) # 2. Create new custom component class Jumbotron(vm.VizroBaseModel): """New custom component `Jumbotron`.""" - type: Literal["jumbotron"] = "jumbotron" + type: Literal["custom_component"] = "custom_component" title: str subtitle: str text: str @@ -1022,7 +1022,7 @@ def build(self): return html.Div([html.H2(self.title), html.H3(self.subtitle), html.P(self.text)]) -vm.Page.add_type("components", Jumbotron) +# vm.Page.add_type("components", Jumbotron) custom_components = vm.Page( title="Custom Components", @@ -1093,7 +1093,7 @@ def multiple_cards(data_frame: pd.DataFrame, n_rows: int | None = 1) -> html.Div # DASHBOARD ------------------------------------------------------------------- -components = [graphs, ag_grid, table, cards, figure, button, containers, tabs, tooltip] +components = [tabs, graphs, ag_grid, table, cards, figure, button, containers, tooltip] controls = [filters, parameters, selectors, controls_in_containers] actions = [export_data_action] layout = [grid_layout, flex_layout] @@ -1101,7 +1101,7 @@ def multiple_cards(data_frame: pd.DataFrame, n_rows: int | None = 1) -> html.Div dashboard = vm.Dashboard( title="Vizro Features", - pages=[home, *components, *controls, *actions, *layout, *extensions], + pages=[home, *controls, *actions, *layout, *extensions, *components], navigation=vm.Navigation( nav_selector=vm.NavBar( items=[ @@ -1140,6 +1140,7 @@ def multiple_cards(data_frame: pd.DataFrame, n_rows: int | None = 1) -> html.Div if __name__ == "__main__": # Move app definition outside of __main__ block for the HF demo to work + print("================ STARTING TREE BUILDING ==================") app = Vizro().build(dashboard) app.dash.layout.children.append( dbc.NavLink( @@ -1149,4 +1150,4 @@ def multiple_cards(data_frame: pd.DataFrame, n_rows: int | None = 1) -> html.Div class_name="anchor-container", ) ) - app.run() + app.run(debug=False) diff --git a/vizro-core/examples/dev/notebook/example.ipynb b/vizro-core/examples/dev/notebook/example.ipynb new file mode 100644 index 000000000..4324e117a --- /dev/null +++ b/vizro-core/examples/dev/notebook/example.ipynb @@ -0,0 +1,158 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "2a4a1132", + "metadata": {}, + "outputs": [], + "source": [ + "import vizro.models as vm\n", + "from vizro import Vizro" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "125bf3e0", + "metadata": {}, + "outputs": [], + "source": [ + "card = vm.Card(\n", + " id=\"card_1\",\n", + " text=\"Card 1\",\n", + ")\n", + "\n", + "card2 = vm.Card(\n", + " id=\"card_2\",\n", + " text=\"Card 2\",\n", + ")\n", + "\n", + "page = vm.Page(\n", + " # id=\"page_1\",\n", + " title=\"Page 1\",\n", + " components=[card, card2],\n", + ")\n", + "\n", + "dashboard = vm.Dashboard(\n", + " id=\"dashboard_1\",\n", + " title=\"Dashboard 1\",\n", + " pages=[page],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3c5f8eb", + "metadata": {}, + "outputs": [], + "source": [ + "from dash._callback_context import context_value\n", + "\n", + "context_value.set({})\n", + "\n", + "dashboard = vm.Dashboard.model_validate(dashboard, context={\"build_tree\": True})\n", + "app = Vizro().build(dashboard)\n", + "\n", + "# TODO: Make this work!\n", + "app.run(debug=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "115a8176", + "metadata": {}, + "outputs": [], + "source": [ + "dashboard._tree.print(repr=\"{node.kind} -> {node.data.type} (id={node.data.id})\")" + ] + }, + { + "cell_type": "markdown", + "id": "a13a7ee5", + "metadata": {}, + "source": [ + "It seems that a second dashboard is not yet possible, due to path collision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5835c587", + "metadata": {}, + "outputs": [], + "source": [ + "card = vm.Card(\n", + " id=\"card_1\",\n", + " text=\"Card 1 asdf\",\n", + ")\n", + "\n", + "card2 = vm.Card(\n", + " id=\"card_2\",\n", + " text=\"Card 2 asdf\",\n", + ")\n", + "\n", + "page = vm.Page(\n", + " # id=\"page_1\",\n", + " title=\"Page 1 qwer\",\n", + " components=[card, card2],\n", + ")\n", + "\n", + "dashboard = vm.Dashboard(\n", + " id=\"dashboard_1\",\n", + " title=\"Dashboard 1 asdf\",\n", + " pages=[page],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e47f325", + "metadata": {}, + "outputs": [], + "source": [ + "from dash._callback_context import context_value\n", + "\n", + "context_value.set({})\n", + "\n", + "dashboard = vm.Dashboard.model_validate(dashboard, context={\"build_tree\": True})\n", + "app = Vizro().build(dashboard)\n", + "\n", + "# TODO: Make this work!\n", + "app.run(debug=False, port=8055)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb9f29f9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "vizro", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/vizro-core/examples/scratch_dev/app.py b/vizro-core/examples/scratch_dev/app.py index 40c2be634..99d9a2658 100644 --- a/vizro-core/examples/scratch_dev/app.py +++ b/vizro-core/examples/scratch_dev/app.py @@ -1,170 +1,191 @@ -import vizro.plotly.express as px +from typing import Literal + import vizro.models as vm -import vizro.actions as va +import vizro.plotly.express as px +from dash import html from vizro import Vizro -from vizro.models.types import capture -from time import sleep +from vizro.managers import data_manager +from vizro.tables import dash_ag_grid df = px.data.iris() +data_manager["iris"] = df -page_1 = vm.Page( - title="Test dmc notification system", - layout=vm.Flex(), - components=[ - vm.Button( - icon="check_circle", - text="Success Notification", - actions=[ - va.show_notification( - text="Operation completed successfully!", - variant="success", - ) - ], - ), - vm.Button( - icon="warning", - text="Warning Notification", - actions=[ - va.show_notification( - text="Please review this warning message.", - variant="warning", - ) - ], - ), - vm.Button( - text="Error Notification", - icon="error", - actions=[ - va.show_notification( - text="An error occurred during the operation.", - variant="error", - ) - ], - ), - vm.Button( - text="Info Notification", - icon="info", - actions=[ - va.show_notification( - text="Here's some useful information for you.", - variant="info", - ) - ], - ), - vm.Button( - text="Loading Notification", - icon="hourglass_empty", - actions=[ - va.show_notification( - text="Processing your request...", - variant="progress", - ) - ], - ), - vm.Button( - text="No Auto-Close", - icon="close", - actions=[ - va.show_notification( - text="This notification will stay until you close it manually.", - title="Persistent", - variant="info", - auto_close=False, - ) - ], - ), - vm.Button( - text="Custom Icon", - icon="celebration", - actions=[ - va.show_notification( - text="Check out this new feature!", - title="New Feature", - variant="success", - icon="celebration", +data_manager["gapminder_2007"] = px.data.gapminder().query("year == 2007") + +gapminder = px.data.gapminder() +iris = px.data.iris() +tips = px.data.tips() + + +# 2. Create new custom component +class Jumbotron(vm.VizroBaseModel): + """New custom component `Jumbotron`.""" + + type: Literal["custom_component"] = "custom_component" + title: str + subtitle: str + text: str + + def build(self): + """Build the new component based on Dash components.""" + return html.Div([html.H2(self.title), html.H3(self.subtitle), html.P(self.text)]) + + +class CustomCard(vm.VizroBaseModel): + """New custom component `Card`.""" + + type: Literal["custom_component"] = "custom_component" + title: str + description: str + + def build(self): + """Build the new component based on Dash components.""" + return html.Div( + [ + html.Div( + [ + html.H4(self.title, style={"margin": "0 0 10px 0"}), + html.P(self.description, style={"margin": "0"}), + ], + style={ + "border": "1px solid #ddd", + "border-radius": "8px", + "padding": "16px", + "background-color": "#f9f9f9", + }, ) - ], + ] + ) + + +page = vm.Page( + title="My first dashboard", + components=[ + vm.Graph(figure=px.scatter(df, x="sepal_length", y="petal_width", color="species")), + vm.Graph(figure=px.histogram(df, x="sepal_width", color="species")), + Jumbotron( + title="Custom component", + subtitle="This is a subtitle", + text="This is the main body of text of the Jumbotron.", ), - vm.Button( - text="Markdown with Link", - icon="link", - actions=[ - va.show_notification( - text="Visit the [Vizro documentation](https://vizro.readthedocs.io/en/stable/) for more details!", - title="", - auto_close=False, - ) - ], + Jumbotron( + title="Custom component 2", + subtitle="This is a subtitle", + text="This is the main body of text of the Jumbotron.", ), - vm.Button( - text="1. Show Loading", - icon="hourglass_empty", - actions=[ - va.show_notification( - id="update-demo", - text="Processing your request...", - title="Processing", - variant="progress", - ) - ], + CustomCard( + title="Custom card", + description="This is a description of the custom card.", ), - vm.Button( - text="2. Update to Complete", - icon="done", - actions=[ - va.update_notification( - notification="update-demo", - text="Your request has been processed successfully!", - title="Complete", - variant="success", - ) - ], + ], + controls=[ + vm.Filter(column="species"), + ], +) + +tab_1 = vm.Container( + id="container_1", + title="Tab I", + components=[ + vm.Graph( + figure=px.bar( + "gapminder_2007", + title="Graph 1", + x="continent", + y="lifeExp", + color="continent", + ), ), - vm.Button( - text="Show Navigation Notification", - icon="arrow_forward", - actions=[ - va.show_notification( - text="Click [here](/page-two) to go to **Page 2** and explore more features!", - title="Ready to explore?", - variant="info", - auto_close=False, - ) - ], + vm.Graph( + figure=px.box( + "gapminder_2007", + title="Graph 2", + x="continent", + y="lifeExp", + color="continent", + ), ), + vm.Graph(figure=px.scatter(iris, x="sepal_width", y="petal_length"), title="Title"), + vm.AgGrid(figure=dash_ag_grid(data_frame=iris)), ], ) - -page_two = vm.Page( - id="page-two", - title="Page Two", - controls=[vm.Filter(column="species")], +tab_2 = vm.Container( + id="container_2", + title="Tab II", components=[ - vm.Graph(figure=px.histogram(df, x="sepal_length")), - vm.Button( - icon="file_download", - text="Export data notification", - actions=[ - va.show_notification( - id="export-notif", - text="Export data starting...", - title="", - variant="progress", - ), - vm.Action(function=capture("action")(lambda: sleep(2.5))()), - va.export_data(), - va.update_notification( - notification="export-notif", - text="Export data completed successfully!", - variant="success", - ), - ], + vm.Graph( + figure=px.scatter( + "gapminder_2007", + title="Graph 3", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), ], ) -dashboard = vm.Dashboard(pages=[page_1, page_two], title="Test Dashboard") +tabs = vm.Page( + id="page_1", + title="Tabs", + components=[vm.Tabs(id="tabs_1", tabs=[tab_1, tab_2])], + controls=[vm.Filter(id="filter_1", column="continent")], +) + +dashboard = vm.Dashboard(id="dashboard_1", pages=[tabs]) + +# Same configuration as JSON +# dashboard_config = { +# "type": "dashboard", +# "pages": [ +# { +# "type": "page", +# "title": "My first dashboard", +# "components": [ +# { +# "type": "graph", +# "figure": { +# "_target_": "scatter", +# "data_frame": "iris", +# "x": "sepal_length", +# "y": "petal_width", +# "color": "species", +# }, +# }, +# { +# "type": "graph", +# "figure": { +# "_target_": "histogram", +# "data_frame": "iris", +# "x": "sepal_width", +# "color": "species", +# }, +# }, +# ], +# "controls": [ +# { +# "type": "filter", +# "column": "species", +# } +# ], +# } +# ], +# } + +# dashboard = vm.Dashboard.model_validate(dashboard_config) if __name__ == "__main__": - Vizro().build(dashboard).run() + app = Vizro().build(dashboard) + # TODO: Do these tests elsewhere + # for node in dashboard._tree: + # has_tree = hasattr(node.data, "_tree") + # if not has_tree: + # print(f"WARNING: {node.kind} (id={node.data.id}) missing ._tree attribute") + # assert all( + # dashboard._tree[model.id].data is model + # for model in [dashboard] + dashboard.pages + [comp for page in dashboard.pages for comp in page.components] + # ) + + app.run() diff --git a/vizro-core/pyproject.toml b/vizro-core/pyproject.toml index 7666d3d6b..3f7eeb0dc 100644 --- a/vizro-core/pyproject.toml +++ b/vizro-core/pyproject.toml @@ -30,7 +30,8 @@ dependencies = [ "black", "autoflake", "packaging", - "python-box" + "python-box", + "nutree" ] description = "Vizro is a low-code framework for building high-quality data visualization apps." dynamic = ["version"] diff --git a/vizro-core/src/vizro/_vizro.py b/vizro-core/src/vizro/_vizro.py index 6c3e934f6..30142e380 100644 --- a/vizro-core/src/vizro/_vizro.py +++ b/vizro-core/src/vizro/_vizro.py @@ -139,6 +139,9 @@ def build(self, dashboard: Dashboard) -> Self: pio.templates.default = dashboard.theme # Note that model instantiation and pre_build are independent of Dash. + # TMP: Set the tree + dashboard = dashboard.__class__.model_validate(dashboard, context={"build_tree": True}) + model_manager._dashboard_tree = dashboard._tree self._pre_build() self.dash.layout = dashboard.build() @@ -204,6 +207,8 @@ def _pre_build(): # Any models that are created during the pre-build process *will not* themselves have pre_build run on them. # In future may add a second pre_build loop after the first one. + # TODO: Things fail here because the MM copy of the model is outdated - fix in next iteration + # This is also the reason why this is not replicated in fake Vizro for filter in cast(Iterable[Filter], model_manager._get_models(Filter)): # Run pre_build on all filters first, then on all other models. This handles dependency between Filter # and Page pre_build and ensures that filters are pre-built before the Page objects that use them. @@ -306,3 +311,11 @@ def _make_resource_spec(path: Path) -> _ResourceType: resource_spec["dynamic"] = True return resource_spec + + +"""FURTHER PLAN OF ACTION + +- convert MM from dictionary to just tree reference, populate at correct moment +- then replace all methods with tree lookups +- then remove the MM from its global state +""" diff --git a/vizro-core/src/vizro/managers/_model_manager.py b/vizro-core/src/vizro/managers/_model_manager.py index 4e4e586ea..637b2e662 100644 --- a/vizro-core/src/vizro/managers/_model_manager.py +++ b/vizro-core/src/vizro/managers/_model_manager.py @@ -5,7 +5,7 @@ from collections.abc import Collection, Generator, Iterable, Mapping from typing import TYPE_CHECKING, TypeVar, cast -from vizro.managers._managers_utils import _state_modifier +from nutree.typed_tree import TypedTree if TYPE_CHECKING: from vizro.models import Page, VizroBaseModel @@ -22,34 +22,20 @@ class FIGURE_MODELS: pass +# TODO: Re-implement the dupicate ID error and investigate further why things (atm) are working nonetheless: +# Duplciate ID on Tabs Containers caused the duplicate not to appear in the tree, BUT the page worked as intended, why? +# Very likely because we still get to the model via the parent model when we iterator over the tabs field class DuplicateIDError(ValueError): """Useful for providing a more explicit error message when a model has id set automatically, e.g. Page.""" class ModelManager: def __init__(self): - self.__models: dict[ModelID, VizroBaseModel] = {} self._frozen_state = False - - # TODO: Consider storing "page_id" or "parent_model_id" and make searching helper methods easier? - @_state_modifier - def __setitem__(self, model_id: ModelID, model: Model): - if model_id in self.__models: - raise DuplicateIDError( - f"Model with id={model_id} already exists. Models must have a unique id across the whole dashboard. " - f"If you are working from a Jupyter Notebook, please either restart the kernel, or " - f"use 'from vizro import Vizro; Vizro._reset()`." - ) - self.__models[model_id] = model - - @_state_modifier - def __delitem__(self, model_id: ModelID): - # Only required to handle legacy actions and could be removed when those are no longer needed. - del self.__models[model_id] + self._dashboard_tree: TypedTree | None = None def __getitem__(self, model_id: ModelID) -> VizroBaseModel: - # Do we need to return deepcopy(self.__models[model_id]) to avoid adjusting element by accident? - return self.__models[model_id] + return self._dashboard_tree.find_first(data_id=model_id).data def __iter__(self) -> Generator[ModelID, None, None]: """Iterates through all models. @@ -58,7 +44,8 @@ def __iter__(self) -> Generator[ModelID, None, None]: """ # TODO: should this yield models rather than model IDs? Should model_manager be more like set with a special # lookup by model ID or more like dictionary? - yield from self.__models + for node in self._dashboard_tree.iterator(): + yield node.data.id def _get_models( self, @@ -71,31 +58,42 @@ def _get_models( If `root_model` is specified, return only models that are descendants of the given `root_model`. """ import vizro.models as vm + from vizro.models import VizroBaseModel if model_type is FIGURE_MODELS: model_type = (vm.Graph, vm.AgGrid, vm.Table, vm.Figure) # type: ignore[assignment] - models = self.__get_model_children(root_model) if root_model is not None else self.__models.values() # type: ignore[type-var] + + # Get models from tree based on root_model + if root_model is None: + # Iterate entire tree + models = (n.data for n in self._dashboard_tree.iterator() if isinstance(n.data, VizroBaseModel)) + elif isinstance(root_model, VizroBaseModel): + # Single model - get its descendants + models = self.__get_model_children(root_model) + elif isinstance(root_model, Mapping): + # Mapping - extract VizroBaseModel instances and get descendants for each + models = [] + for child in root_model.values(): + if isinstance(child, VizroBaseModel): + models.extend(list(self.__get_model_children(child))) + elif isinstance(root_model, Collection) and not isinstance(root_model, str): + # Collection - extract VizroBaseModel instances and get descendants for each + models = [] + for child in root_model: + if isinstance(child, VizroBaseModel): + models.extend(list(self.__get_model_children(child))) + else: + return # return empty generator # Convert to list to avoid changing size when looping through at runtime. - for model in list(models): + for model in models: if model_type is None or isinstance(model, model_type): yield model # type: ignore[misc] def __get_model_children(self, model: Model) -> Generator[Model, None, None]: """Iterates through children of `model` with depth-first pre-order traversal.""" - from vizro.models import VizroBaseModel - - if isinstance(model, VizroBaseModel): - yield model - for model_field in model.__class__.model_fields: - yield from self.__get_model_children(getattr(model, model_field)) - elif isinstance(model, Mapping): - # We don't look through keys because Vizro models aren't hashable. - for child in model.values(): - yield from self.__get_model_children(child) - elif isinstance(model, Collection) and not isinstance(model, str): - for child in model: - yield from self.__get_model_children(child) + node = self._dashboard_tree.find_first(data_id=model.id) + yield from (n.data for n in node.iterator(add_self=True)) def _get_model_page(self, model: Model) -> Page: # type: ignore[return] """Gets the page containing `model`.""" diff --git a/vizro-core/src/vizro/models/_base.py b/vizro-core/src/vizro/models/_base.py index 714d52cef..6d12f0d48 100644 --- a/vizro-core/src/vizro/models/_base.py +++ b/vizro-core/src/vizro/models/_base.py @@ -1,26 +1,39 @@ +from __future__ import annotations + import inspect import logging import random import textwrap import uuid -from typing import Annotated, Any, Union, cast, get_args, get_origin +from collections.abc import Mapping +from types import SimpleNamespace +from typing import Annotated, Any, Literal, Self, TypeVar, Union, cast, get_args, get_origin import autoflake import black +from nutree.typed_tree import TypedTree from pydantic import ( BaseModel, ConfigDict, Field, + ModelWrapValidatorHandler, + PrivateAttr, SerializationInfo, SerializerFunctionWrapHandler, + ValidatorFunctionWrapHandler, + field_validator, model_serializer, + model_validator, ) from pydantic.fields import FieldInfo +from pydantic_core.core_schema import ValidationInfo from vizro.managers import model_manager -from vizro.models._models_utils import REPLACEMENT_STRINGS, _log_call +from vizro.models._models_utils import REPLACEMENT_STRINGS from vizro.models.types import ModelID +Model = TypeVar("Model", bound="VizroBaseModel") + # As done for Dash components in dash.development.base_component, fixing the random seed is required to make sure that # the randomly generated model ID for the same model matches up across workers when running gunicorn without --preload. rd = random.Random(0) # noqa: S311 @@ -207,6 +220,25 @@ def _split_types(type_annotation: type[Any]) -> type[Any]: ) +def _validate_with_tree_context( + model: Model, + parent_model: Model, + field_name: str, +) -> Model: + """Validate a model instance with tree-building context.""" + # if not isinstance(model, VizroBaseModel): + # raise ValueError(f"Model must be a subclass of VizroBaseModel: {model}") + return type(model).model_validate( + model, + context={ + "build_tree": True, + "parent_model": parent_model, + "field_stack": [field_name], + "id_stack": [parent_model.id], + }, + ) + + class VizroBaseModel(BaseModel): """All Vizro models inherit from this class. @@ -216,9 +248,16 @@ class VizroBaseModel(BaseModel): Args: id (ModelID): ID to identify model. Must be unique throughout the whole dashboard. When no ID is chosen, ID will be automatically generated. + type: Type identifier for the model. Defaults to "vizro_base_model" for the base class. + Subclasses should override with their specific Literal type. + Custom components should set type: str = "custom_component" or type: Literal["custom_component"] = "custom_component" """ + # Default type for base model. Subclasses should override with their specific Literal type. + # Custom components should set type: str = "custom_component" or type: Literal["custom_component"] = "custom_component" + type: Literal["vizro_base_model"] = Field(default="vizro_base_model") + id: Annotated[ ModelID, Field( @@ -228,10 +267,172 @@ class VizroBaseModel(BaseModel): validate_default=True, ), ] - - @_log_call - def model_post_init(self, context: Any) -> None: - model_manager[self.id] = self + _tree: TypedTree | None = PrivateAttr(None) # initialised in model_after + + # Next TODO: + # why do we end up with ._tree = None + # is it legit to just copy all private attributes + # why should that wrap validator be the last one, is it because otherwise it misses things when not + # having finished? + # We should document the error to also discuss with rest of the team... + # Idea for tomorrow: why don't we check the diff between before and after too? + + # @_log_call + # def model_post_init(self, context: Any) -> None: + # model_manager[self.id] = self + + # TODO: is that really used? + @staticmethod + def _ensure_model_in_tree(model: VizroBaseModel, context: dict[str, Any]) -> VizroBaseModel: + """Revalidate a VizroBaseModel instance if it hasn't been added to the tree yet.""" + has_tree = hasattr(model, "_tree") + tree_is_none = getattr(model, "_tree", None) is None + if not has_tree or tree_is_none: + # Revalidate with build_tree context to ensure tree node is created + return model.__class__.model_validate(model, context=context) + return model + + @staticmethod + def _ensure_models_in_tree(validated_stuff: Any, context: dict[str, Any]) -> Any: + """Recursively ensure all VizroBaseModel instances in a structure are added to the tree.""" + if isinstance(validated_stuff, VizroBaseModel): + return VizroBaseModel._ensure_model_in_tree(validated_stuff, context) + elif isinstance(validated_stuff, list): + return [VizroBaseModel._ensure_models_in_tree(item, context) for item in validated_stuff] + elif isinstance(validated_stuff, Mapping) and not isinstance(validated_stuff, str): + # Note: str is a Mapping in Python, so we exclude it + return type(validated_stuff)( + {key: VizroBaseModel._ensure_models_in_tree(value, context) for key, value in validated_stuff.items()} + ) + return validated_stuff + + @field_validator("*", mode="wrap") + @classmethod + def build_tree_field_wrap( + cls, + value: Any, + handler: ValidatorFunctionWrapHandler, + info: ValidationInfo, + ) -> Any: + if info.context is not None and "build_tree" in info.context: + #### Field stack #### + if "id_stack" not in info.context: + info.context["id_stack"] = [] + if "field_stack" not in info.context: + info.context["field_stack"] = [] + if info.field_name == "id": + info.context["id_stack"].append(value) + else: + info.context["id_stack"].append(info.data.get("id", "no id")) + info.context["field_stack"].append(info.field_name) + + #### Validation #### + validated_stuff = handler(value) + + if info.context is not None and "build_tree" in info.context: + #### Ensure VizroBaseModel instances are added to tree #### + # This handles the case where custom components match 'Any' in discriminated unions + # and might not go through full revalidation + # Note: field_stack and id_stack are still in place here (before the pop below) + # so build_tree_model_wrap will have the correct context + validated_stuff = VizroBaseModel._ensure_models_in_tree(validated_stuff, info.context) + + #### Field stack cleanup #### + # Pop after revalidation so the stacks are available during revalidation + info.context["id_stack"].pop() + info.context["field_stack"].pop() + return validated_stuff + + @model_validator(mode="wrap") + @classmethod + def build_tree_model_wrap(cls, data: Any, handler: ModelWrapValidatorHandler[Self], info: ValidationInfo) -> Self: + # PRIVATE ATTR PRESERVATION + # What could go wrong here? + # Potentially if during validation of a subfield model, i want to set something in a parent + # private attribute, but what i set then get's overwritten later by the saving mechanism? + private_attrs = {} + # If data is already an instance of this model, capture PrivateAttr values + if isinstance(data, cls): + # Get all PrivateAttr fields from the model's __private_attributes__ + if hasattr(cls, "__private_attributes__"): + for attr_name in cls.__private_attributes__.keys(): + # Check if the attribute has been set on the instance + if hasattr(data, attr_name): + # print(f"Capturing PrivateAttr: {attr_name}") + try: + value = getattr(data, attr_name) + # Capture the value (including None if explicitly set) + private_attrs[attr_name] = value + except AttributeError: + pass + + if info.context is not None and "build_tree" in info.context: + #### ID #### + if isinstance(data, dict): + model_id = data.get("id") + else: + model_id = getattr(data, "id", None) + + if model_id is None: + raise ValueError(f"Cannot determine model id for data: {data}") + + #### Level and indentation #### + if "level" not in info.context: + info.context["level"] = 0 + indent = info.context["level"] * " " * 4 + info.context["level"] += 1 + + #### Tree #### + # print( + # f"{indent}{cls.__name__} Before validation: {info.context['field_stack'] if 'field_stack' in info.context else 'no field stack'} with model id {model_id}" + # ) + + # Skip addition if node already exists in tree (reason not yet understood) + if not ("tree" in info.context and info.context["tree"].find_first(data_id=model_id)): + if "parent_model" in info.context: # is that from pre_build? + info.context["tree"] = info.context["parent_model"]._tree + tree = info.context["tree"] + tree[info.context["parent_model"].id].add( + SimpleNamespace(id=model_id), kind=info.context["field_stack"][-1] + ) + elif "tree" not in info.context: + tree = TypedTree("Root", calc_data_id=lambda tree, data: data.id) + tree.add(SimpleNamespace(id=model_id), kind="dashboard") # TODO: make this more general + info.context["tree"] = tree + else: + tree = info.context["tree"] + # in words: add a node as children to the parent (so id one higher up), but add as kind + # the field in which you currently are + # ID STACK and FIELD STACK are different "levels" of the tree. + tree[info.context["id_stack"][-1]].add( + SimpleNamespace(id=model_id), kind=info.context["field_stack"][-1] + ) + + #### Validation #### + validated_stuff = handler(data) + + # Restore PrivateAttr values if we captured any + for attr_name, value in private_attrs.items(): + setattr(validated_stuff, attr_name, value) + + if info.context is not None and "build_tree" in info.context: + #### Replace placeholder nodes and propagate tree to all models #### + info.context["tree"][validated_stuff.id].set_data(validated_stuff) + validated_stuff._tree = info.context["tree"] + + #### Level and indentation #### + info.context["level"] -= 1 + indent = info.context["level"] * " " * 4 + # print(f"{indent}{cls.__name__} After validation: {info.context['field_stack']}") + elif hasattr(data, "_tree") and data._tree is not None: + #### Revalidation case: model already has a tree (e.g., during assignment) #### + # Inherit the tree from the original instance + validated_stuff._tree = data._tree + # Update the tree node to point to the NEW validated instance + validated_stuff._tree[validated_stuff.id].set_data(validated_stuff) + print(f"--> Revalidation: Updated tree node for {validated_stuff.id} <--") + + return validated_stuff # Previously in V1, we used to have an overwritten `.dict` method, that would add __vizro_model__ to the dictionary # if called in the correct context. @@ -345,4 +546,5 @@ def _to_python(self, extra_imports: set[str] | None = None, extra_callable_defs: model_config = ConfigDict( extra="forbid", # Good for spotting user typos and being strict. validate_assignment=True, # Run validators when a field is assigned after model instantiation. + revalidate_instances="always", ) diff --git a/vizro-core/src/vizro/models/_components/ag_grid.py b/vizro-core/src/vizro/models/_components/ag_grid.py index 8c80f3afc..272c0eceb 100644 --- a/vizro-core/src/vizro/models/_components/ag_grid.py +++ b/vizro-core/src/vizro/models/_components/ag_grid.py @@ -19,7 +19,14 @@ warn_description_without_title, ) from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, CapturedCallable, MultiValueType, _IdProperty, _validate_captured_callable +from vizro.models.types import ( + ActionsType, + CapturedCallable, + MultiValueType, + _IdProperty, + _validate_captured_callable, + make_discriminated_union, +) logger = logging.getLogger(__name__) @@ -71,7 +78,7 @@ class AgGrid(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/button.py b/vizro-core/src/vizro/models/_components/button.py index 3e982879b..30c5f8e48 100644 --- a/vizro-core/src/vizro/models/_components/button.py +++ b/vizro-core/src/vizro/models/_components/button.py @@ -8,7 +8,7 @@ from vizro.models import Tooltip, VizroBaseModel from vizro.models._models_utils import _log_call, make_actions_chain, validate_icon from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, _IdProperty +from vizro.models.types import ActionsType, _IdProperty, make_discriminated_union class Button(VizroBaseModel): @@ -51,7 +51,7 @@ class Button(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), # AfterValidator(warn_description_without_title) is not needed here because either 'text' or 'icon' argument # is mandatory. diff --git a/vizro-core/src/vizro/models/_components/card.py b/vizro-core/src/vizro/models/_components/card.py index 1041c7952..b7948f951 100644 --- a/vizro-core/src/vizro/models/_components/card.py +++ b/vizro-core/src/vizro/models/_components/card.py @@ -8,7 +8,7 @@ from vizro.models import Tooltip, VizroBaseModel from vizro.models._models_utils import _log_call, make_actions_chain from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, _IdProperty +from vizro.models.types import ActionsType, _IdProperty, make_discriminated_union class Card(VizroBaseModel): @@ -57,7 +57,7 @@ class Card(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), Field( default=None, diff --git a/vizro-core/src/vizro/models/_components/container.py b/vizro-core/src/vizro/models/_components/container.py index f3594830b..42efb98e2 100644 --- a/vizro-core/src/vizro/models/_components/container.py +++ b/vizro-core/src/vizro/models/_components/container.py @@ -20,7 +20,7 @@ warn_description_without_title, ) from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ComponentType, ControlType, LayoutType, _IdProperty +from vizro.models.types import ComponentType, ControlType, LayoutType, _IdProperty, make_discriminated_union # TODO: this could be done with default_factory once we bump to pydantic>=2.10.0. @@ -85,7 +85,7 @@ class Container(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/form/_text_area.py b/vizro-core/src/vizro/models/_components/form/_text_area.py index 449ae231c..4b7e6d701 100755 --- a/vizro-core/src/vizro/models/_components/form/_text_area.py +++ b/vizro-core/src/vizro/models/_components/form/_text_area.py @@ -7,7 +7,7 @@ from vizro.models import Tooltip, VizroBaseModel from vizro.models._models_utils import _log_call, warn_description_without_title from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, _IdProperty +from vizro.models.types import ActionsType, _IdProperty, make_discriminated_union class TextArea(VizroBaseModel): @@ -27,7 +27,7 @@ class TextArea(VizroBaseModel): # TODO: before making public consider naming this field (or giving an alias) label instead of title title: str = Field(default="", description="Title to be displayed") description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/form/_user_input.py b/vizro-core/src/vizro/models/_components/form/_user_input.py index c847401ea..c5858fd79 100644 --- a/vizro-core/src/vizro/models/_components/form/_user_input.py +++ b/vizro-core/src/vizro/models/_components/form/_user_input.py @@ -7,7 +7,7 @@ from vizro.models import Tooltip, VizroBaseModel from vizro.models._models_utils import _log_call, warn_description_without_title from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, _IdProperty +from vizro.models.types import ActionsType, _IdProperty, make_discriminated_union class UserInput(VizroBaseModel): @@ -27,7 +27,7 @@ class UserInput(VizroBaseModel): # TODO: before making public consider naming this field (or giving an alias) label instead of title title: str = Field(default="", description="Title to be displayed") description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/form/checklist.py b/vizro-core/src/vizro/models/_components/form/checklist.py index 47e399d55..bbb931799 100644 --- a/vizro-core/src/vizro/models/_components/form/checklist.py +++ b/vizro-core/src/vizro/models/_components/form/checklist.py @@ -17,7 +17,7 @@ warn_description_without_title, ) from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, MultiValueType, OptionsType, _IdProperty +from vizro.models.types import ActionsType, MultiValueType, OptionsType, _IdProperty, make_discriminated_union class Checklist(VizroBaseModel): @@ -57,7 +57,7 @@ class Checklist(VizroBaseModel): "options with a single click.", ) description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/form/date_picker.py b/vizro-core/src/vizro/models/_components/form/date_picker.py index f0e6c8c88..645c5601b 100644 --- a/vizro-core/src/vizro/models/_components/form/date_picker.py +++ b/vizro-core/src/vizro/models/_components/form/date_picker.py @@ -15,7 +15,7 @@ warn_description_without_title, ) from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, _IdProperty +from vizro.models.types import ActionsType, _IdProperty, make_discriminated_union class DatePicker(VizroBaseModel): @@ -64,7 +64,7 @@ class DatePicker(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/form/dropdown.py b/vizro-core/src/vizro/models/_components/form/dropdown.py index 607fdd780..34e04a25d 100755 --- a/vizro-core/src/vizro/models/_components/form/dropdown.py +++ b/vizro-core/src/vizro/models/_components/form/dropdown.py @@ -21,6 +21,7 @@ OptionsType, SingleValueType, _IdProperty, + make_discriminated_union, ) @@ -93,7 +94,7 @@ class Dropdown(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), Field( default=None, diff --git a/vizro-core/src/vizro/models/_components/form/radio_items.py b/vizro-core/src/vizro/models/_components/form/radio_items.py index 9050601a4..cc9c99815 100644 --- a/vizro-core/src/vizro/models/_components/form/radio_items.py +++ b/vizro-core/src/vizro/models/_components/form/radio_items.py @@ -13,7 +13,7 @@ ) from vizro.models._models_utils import _log_call, make_actions_chain from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, OptionsType, SingleValueType, _IdProperty +from vizro.models.types import ActionsType, OptionsType, SingleValueType, _IdProperty, make_discriminated_union class RadioItems(VizroBaseModel): @@ -47,7 +47,7 @@ class RadioItems(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), Field( default=None, diff --git a/vizro-core/src/vizro/models/_components/form/range_slider.py b/vizro-core/src/vizro/models/_components/form/range_slider.py index 8a64ef8e7..47503e272 100644 --- a/vizro-core/src/vizro/models/_components/form/range_slider.py +++ b/vizro-core/src/vizro/models/_components/form/range_slider.py @@ -18,7 +18,7 @@ warn_description_without_title, ) from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, _IdProperty +from vizro.models.types import ActionsType, _IdProperty, make_discriminated_union class RangeSlider(VizroBaseModel): @@ -69,7 +69,7 @@ class RangeSlider(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/form/slider.py b/vizro-core/src/vizro/models/_components/form/slider.py index 63aa57e82..98e170ca8 100644 --- a/vizro-core/src/vizro/models/_components/form/slider.py +++ b/vizro-core/src/vizro/models/_components/form/slider.py @@ -18,7 +18,7 @@ warn_description_without_title, ) from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, _IdProperty +from vizro.models.types import ActionsType, _IdProperty, make_discriminated_union class Slider(VizroBaseModel): @@ -69,7 +69,7 @@ class Slider(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/form/switch.py b/vizro-core/src/vizro/models/_components/form/switch.py index 0801914f7..9dd1b5f24 100755 --- a/vizro-core/src/vizro/models/_components/form/switch.py +++ b/vizro-core/src/vizro/models/_components/form/switch.py @@ -8,7 +8,7 @@ from vizro.models import Tooltip, VizroBaseModel from vizro.models._models_utils import make_actions_chain, warn_description_without_title from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, _IdProperty +from vizro.models.types import ActionsType, _IdProperty, make_discriminated_union class Switch(VizroBaseModel): @@ -43,7 +43,7 @@ class Switch(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/graph.py b/vizro-core/src/vizro/models/_components/graph.py index 7d25c99db..a581243ce 100644 --- a/vizro-core/src/vizro/models/_components/graph.py +++ b/vizro-core/src/vizro/models/_components/graph.py @@ -24,12 +24,19 @@ ) from vizro.models._tooltip import coerce_str_to_tooltip from vizro.models.types import ( + ( ActionsType, + CapturedCallable, + ModelID, + MultiValueType, _IdProperty, + _validate_captured_callable, + make_discriminated_union, +), ) logger = logging.getLogger(__name__) @@ -85,7 +92,7 @@ class Graph(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/table.py b/vizro-core/src/vizro/models/_components/table.py index f765ff154..de09d879f 100644 --- a/vizro-core/src/vizro/models/_components/table.py +++ b/vizro-core/src/vizro/models/_components/table.py @@ -18,7 +18,13 @@ warn_description_without_title, ) from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ActionsType, CapturedCallable, _IdProperty, _validate_captured_callable +from vizro.models.types import ( + ActionsType, + CapturedCallable, + _IdProperty, + _validate_captured_callable, + make_discriminated_union, +) logger = logging.getLogger(__name__) @@ -65,7 +71,7 @@ class Table(VizroBaseModel): # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_components/tabs.py b/vizro-core/src/vizro/models/_components/tabs.py index 0835348d3..89c2a9ae6 100644 --- a/vizro-core/src/vizro/models/_components/tabs.py +++ b/vizro-core/src/vizro/models/_components/tabs.py @@ -9,7 +9,7 @@ from vizro.models import Tooltip, VizroBaseModel from vizro.models._models_utils import _log_call, warn_description_without_title from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import _IdProperty +from vizro.models.types import _IdProperty, make_discriminated_union if TYPE_CHECKING: from vizro.models._components import Container @@ -40,7 +40,7 @@ class Tabs(VizroBaseModel): tabs: conlist(Annotated[Container, AfterValidator(validate_tab_has_title)], min_length=1) # type: ignore[valid-type] title: str = Field(default="", description="Title displayed above Tabs.") description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_controls/filter.py b/vizro-core/src/vizro/models/_controls/filter.py index baec01b68..8444e72e2 100644 --- a/vizro-core/src/vizro/models/_controls/filter.py +++ b/vizro-core/src/vizro/models/_controls/filter.py @@ -15,6 +15,7 @@ from vizro.managers._data_manager import DataSourceName, _DynamicData from vizro.managers._model_manager import FIGURE_MODELS from vizro.models import VizroBaseModel +from vizro.models._base import _validate_with_tree_context from vizro.models._components.form import DatePicker, Dropdown, RangeSlider, Switch from vizro.models._controls._controls_utils import ( SELECTORS, @@ -238,7 +239,10 @@ def pre_build(self): # Set default selector according to column type. self._column_type = self._validate_column_type(targeted_data) - self.selector = self.selector or DEFAULT_SELECTORS[self._column_type]() + if not self.selector: + self.selector = _validate_with_tree_context( + DEFAULT_SELECTORS[self._column_type](), parent_model=self, field_name="selector" + ) self.selector.title = self.selector.title or self.column.title() if isinstance(self.selector, DISALLOWED_SELECTORS.get(self._column_type, ())): @@ -287,12 +291,16 @@ def pre_build(self): filter_function = _filter_isin self.selector.actions = [ - _filter( - id=f"{FILTER_ACTION_PREFIX}_{self.id}", - column=self.column, - filter_function=filter_function, - targets=self.targets, - ), + _validate_with_tree_context( + _filter( + id=f"{FILTER_ACTION_PREFIX}_{self.id}", + column=self.column, + filter_function=filter_function, + targets=self.targets, + ), + parent_model=self, + field_name="actions", + ) ] # A set of properties unique to selector (inner object) that are not present in html.Div (outer build wrapper). diff --git a/vizro-core/src/vizro/models/_controls/parameter.py b/vizro-core/src/vizro/models/_controls/parameter.py index f426845f4..8d79d884b 100644 --- a/vizro-core/src/vizro/models/_controls/parameter.py +++ b/vizro-core/src/vizro/models/_controls/parameter.py @@ -8,6 +8,7 @@ from vizro.actions._parameter_action import _parameter from vizro.managers import model_manager from vizro.models import VizroBaseModel +from vizro.models._base import _validate_with_tree_context from vizro.models._controls._controls_utils import ( _is_categorical_selector, _is_numerical_temporal_selector, @@ -77,17 +78,16 @@ class Parameter(VizroBaseModel): """ type: Literal["parameter"] = "parameter" - targets: Annotated[ # TODO[MS]: check if the double annotation is the best way to do this - list[ - Annotated[ - str, - AfterValidator(check_dot_notation), - AfterValidator(check_data_frame_as_target_argument), - Field(description="Targets in the form of `.`."), - ] - ], - AfterValidator(check_duplicate_parameter_target), - ] + targets: list[ # Annotated[ # TODO[MS]: check if the double annotation is the best way to do this + Annotated[ + str, + AfterValidator(check_dot_notation), + AfterValidator(check_data_frame_as_target_argument), + Field(description="Targets in the form of `.`."), + ] + ] # , + # AfterValidator(check_duplicate_parameter_target), + # ] selector: SelectorType show_in_url: bool = Field( default=False, @@ -185,7 +185,16 @@ def pre_build(self): # pydantic validator like `check_dot_notation` on the `self.targets` again. # We do the update to ensure that `self.targets` is consistent with the targets passed to `_parameter`. self.targets.extend(list(filter_targets)) - self.selector.actions = [_parameter(id=f"{PARAMETER_ACTION_PREFIX}_{self.id}", targets=self.targets)] + self.selector.actions = [ + _validate_with_tree_context( + _parameter( + id=f"{PARAMETER_ACTION_PREFIX}_{self.id}", + targets=self.targets, + ), + parent_model=self.selector, + field_name="actions", + ) + ] @_log_call def build(self): diff --git a/vizro-core/src/vizro/models/_dashboard.py b/vizro-core/src/vizro/models/_dashboard.py index 218053bb0..04d1b07d4 100644 --- a/vizro-core/src/vizro/models/_dashboard.py +++ b/vizro-core/src/vizro/models/_dashboard.py @@ -35,11 +35,11 @@ from vizro.models._controls import Filter, Parameter from vizro.models._models_utils import _all_hidden, _log_call, warn_description_without_title from vizro.models._navigation._navigation_utils import _NavBuildType +from vizro.models._page import Page from vizro.models._tooltip import coerce_str_to_tooltip -from vizro.models.types import ControlType +from vizro.models.types import ControlType, make_discriminated_union if TYPE_CHECKING: - from vizro.models import Page from vizro.models._page import _PageBuildType logger = logging.getLogger(__name__) @@ -100,18 +100,21 @@ class Dashboard(VizroBaseModel): """ - pages: list[Page] + type: Literal["dashboard"] = "dashboard" + pages: list[make_discriminated_union(Page)] theme: Literal["vizro_dark", "vizro_light"] = Field( default="vizro_dark", description="Theme to be applied across dashboard. Defaults to `vizro_dark`." ) navigation: Annotated[ - Navigation | None, AfterValidator(set_navigation_pages), Field(default=None, validate_default=True) + make_discriminated_union(Navigation) | None, + AfterValidator(set_navigation_pages), + Field(default=None, validate_default=True), ] title: str = Field(default="", description="Dashboard title to appear on every page on top left-side.") # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( diff --git a/vizro-core/src/vizro/models/_fake_vizro/actions.py b/vizro-core/src/vizro/models/_fake_vizro/actions.py new file mode 100644 index 000000000..8032c96f4 --- /dev/null +++ b/vizro-core/src/vizro/models/_fake_vizro/actions.py @@ -0,0 +1,14 @@ +"""Actions module - creates circular dependency with models.py. + +models.py imports ExportDataAction from here +actions.py imports VizroBaseModel from models.py +→ CIRCULAR DEPENDENCY! +""" + +from vizro.models._fake_vizro.models import VizroBaseModel + + +class ExportDataAction(VizroBaseModel): + """Export data action that inherits from VizroBaseModel (defined in models.py).""" + + format: str = "csv" diff --git a/vizro-core/src/vizro/models/_fake_vizro/app.py b/vizro-core/src/vizro/models/_fake_vizro/app.py new file mode 100644 index 000000000..937de4bf0 --- /dev/null +++ b/vizro-core/src/vizro/models/_fake_vizro/app.py @@ -0,0 +1,136 @@ +"""Example app.py to play with the fake vizro models.""" + +from typing import Literal + +from vizro.models._fake_vizro.models import ( + Action, + Card, + Component, + Dashboard, + ExportDataAction, + Graph, + Page, + VizroBaseModel, +) + + +# User-defined custom components (realistic usage) +class CustomCard(VizroBaseModel): + """User-defined custom component - demonstrates tree building issue.""" + + type: Literal["custom_component"] = "custom_component" + title: str + + +class CustomPage(Page): + # Allow int + title: int + + +class CustomPageBase(VizroBaseModel): + title: int + components: list[Graph | Card] + + +class CustomGraph(Graph): + figure: int + + +class CustomGraphBase(VizroBaseModel): + figure: int + + +dashboard = Dashboard( + pages=[ + Page(title="page_1", components=[Component(x="c1")]), + Page(title="page_2", components=[Graph(figure="c3", actions=[Action(action="action1")])]), + Page( + title="page_3", + components=[Graph(figure="c3", actions=[Action(action="export", function=ExportDataAction(format="csv"))])], + ), + ] +) + +# dashboard_data = { +# "pages": [ +# { +# "title": "page_1", +# "components": [ +# {"type": "graph", "figure": "c1"}, +# {"type": "card", "text": "some text for card"}, +# ], +# }, +# { +# "title": "page_2", +# "components": [ +# {"type": "graph", "figure": "c3"}, +# # You can add another Card or Graph here if desired +# ], +# }, +# ], +# } + +# Minimal case demonstrating custom component tree building issue +dashboard_with_custom = Dashboard( + pages=[ + Page( + title="test_page", + components=[ + Graph(figure="test_figure", actions=[]), # Built-in component - appears in tree + Card(text="test_card"), # Built-in component - appears in tree + CustomCard(title="custom_card"), # Custom component - MISSING from tree + ], + ), + ], +) + +dashboard_with_custom = Dashboard.model_validate(dashboard_with_custom, context={"build_tree": True}) +print("\n" + "=" * 80) +print("Tree output - CustomCard should now appear:") +print("=" * 80) +dashboard_with_custom._tree.print(repr="{node.kind} -> {node.data.type} (id={node.data.id})") +print("\n" + "=" * 80) + +# dashboard = Dashboard.model_validate(dashboard, context={"build_tree": True}) +# print("--------------------------------") +# for page in dashboard.pages: +# page.pre_build() +# Notes +# Any additional model validate erases private property of tree, but why does it +# NOT erase the _parent_model attribute +# print("--------------------------------") +# dashboard = Dashboard.model_validate(dashboard) + +# comp = Component.from_pre_build( +# {"x": [SubComponent(y="new c3"), SubComponent(y="another new c3")]}, dashboard.pages[0], "components" +# ) +# dashboard._tree.print(repr="{node.data.type} (id={node.data.id})") + +# dashboard.pages[0]._tree.print() # repr="{node.data.type} (id={node.data.id})" +# dashboard.pages[0]._tree.print() # repr="{node.data.type} (id={node.data.id})" +# print("---") +# print(dashboard.pages[0].components[0].actions[0]) +# print(dashboard.pages[0].components[0].actions[0]._parent_model) + +# JSON Schema (commented out to focus on tree building issue) +# graph = Graph(id="graph-id", figure="a", actions=[Action(id="action-id", action="a")]) +# graph = Graph.model_validate(graph) +# print(json.dumps(graph.model_dump(), indent=2)) +# print(json.dumps(ExportDataAction.model_json_schema(), indent=2)) +# print("=" * 100) +# print(json.dumps(Card.model_json_schema(), indent=2)) +# print("=" * 100) +# print(json.dumps(Dashboard.model_json_schema(), indent=2)) +# print("=" * 100) +# print(json.dumps(SubComponent.model_json_schema(), indent=2)) +# print("=" * 100) +# ea = ExportDataAction(format="csv") +# print(json.dumps(ea.model_dump(), indent=2)) + + +""" + + +Run this file to see the error: + $ hatch run python src/vizro/models/_fake_vizro/app.py +""" diff --git a/vizro-core/src/vizro/models/_fake_vizro/models/__init__.py b/vizro-core/src/vizro/models/_fake_vizro/models/__init__.py new file mode 100644 index 000000000..123ac2e21 --- /dev/null +++ b/vizro-core/src/vizro/models/_fake_vizro/models/__init__.py @@ -0,0 +1,34 @@ +from vizro.models._fake_vizro.models.models import ( + Action, + Card, + Component, + Container, + Dashboard, + Graph, + Page, + SubComponent, + Tabs, + VizroBaseModel, +) + +__all__ = [ + "Action", + "Card", + "Component", + "Container", + "Dashboard", + "Graph", + "Page", + "SubComponent", + "Tabs", + "VizroBaseModel", +] + +# To resolve ForwardRefs we need to import ExportDataAction (similar to vizro.models.__init__.py) +# Import after models to avoid circular import +from vizro.models._fake_vizro.actions import ExportDataAction + +# Rebuild all models to resolve forward references +# Below we see that order matters, and while ExportDataAction now builds properly, not all models have a correct schema +for model in ["ExportDataAction", *__all__]: + globals()[model].model_rebuild(force=True) diff --git a/vizro-core/src/vizro/models/_fake_vizro/models/models.py b/vizro-core/src/vizro/models/_fake_vizro/models/models.py new file mode 100644 index 000000000..4f485e9ae --- /dev/null +++ b/vizro-core/src/vizro/models/_fake_vizro/models/models.py @@ -0,0 +1,481 @@ +## Notes from A +"""Fake Vizro models. + +This file illustrates how we can do revalidate_instances="always" and convert every field that's a Vizro model into a +discriminated union even if it has only one type to begin with (e.g. pages: list[Page]). +We forget about add_type and trying to generate a "correct" schema. Anything that's a custom model must stay as it is +and so be validated as Any. This sounds bad but is actually ok by me since the custom model itself performs the validation +it needs. NO - SEE BELOW AAARGH, this is probably not ok :( But I think we can figure it out. +Advantages: +- we keep revalidate_instances="always", so model manager registration works and it feels correct +- no need for add_type anywhere any more! In future we can just remove this entirely. Maybe we can come up with a better +mechanism for modifying the schema in future. +- a custom component can always (a) subclass an existing model in the discriminated union or (b) directly subclass VizroBaseModel. + Previously (a) and (b) worked for discriminated union fields (e.g. components: list[ComponentsType] but only + (a) worked for non-discriminated union fields (e.g. pages: list[Page]). +- hopefully not breaking in any way. +- no need for the stupid type field in all models (this is handled automatically through discriminator function). + We now rely on the class name to say whether it's a custom component or not rather than manually supplying the type. + We could keep with the current type field system if we wanted to but I don't see any adnnatage in doing so. +Disadvantages: +- feels very heavy handed/ugly - JSON schema becomes a mess (though arguably more accurate), API docs get messier +- AAAARGH just realised another problem that's obvious and I didn't think of before. As soon as we have Any then you +could e.g. do Dashboard(pages=[graph]) and validation would pass when it shouldn't. This seems very bad for +Vizro-MCP and life in general. How can we prevent this but still allow custom components to be injected anywhere? Do we +force custom components to have type explicitly specified as "custom_component" instead of just taking it from the class name? +Or we could maintain a list of built-in tags and throw error if the tag matches one of those when it shouldn't to stop +e.g. Graph being used as a Page. Probably the explicit type being specified is best. We can manage this by +distinguishing our built-in models from user-created ones e.g. just by assigning all of ours an explicit type that +custom components wouldn't specify or checking the import path, etc. So overall this is probably not disastrous, +we should be able to validate well enough. The below code doesn't demonstrate this and will work for Dashboard( +pages=[graph]). + +Notes: +- we don't have (or arguably need) a way of updating the schema at all. While you can't add specific types to the schema +any more, arguably the "custom" placeholder is just a general "injection point" for arbitrary types (a bit like our +extra field). +- there's no way to do custom component from YAML at all. Previously this was possible but completely undocumented. +Remaining challenges: +- will the copy on model revalidate cause us real problems in the model manager? As in last post of +https://github.com/pydantic/pydantic/issues/7608. Not worried about tests breaking so long as things work in practice +(and we understand why) since we can fix tests later. +- how to solve the new AAAARGH above, but I have some ideas here so don't worry about it too much. I am pretty +confident this is not a show-stopper and I can handle it when I'm back. +Alternative solution: +- drop revalidate_instances="always", but then need to go back to figure out how to handle model manager in a way that +builds tree correctly and avoids global variable. If we do go down that route then maybe from_attributes would be +useful for DashboardProxy.model_validate(Dashboard, from_attributes=True). See https://github.com/pydantic/pydantic/discussions/8933 +IMPORTANT IDEA we shouldn't forget: +- if/when we have dashboard = model_validate(dashboard) in Vizro.build, the argument dashboard: Dashboard can be much more general +i.e. it could also be dict. This would make loading from yaml/json more direct which would be nice. We could also +introduce and argument json=True/False or context (to match pydantic), which is passed through to the pydantic +context (could even use model_validate_json instead?). This would allow us to handle differently the load from json vs. +Python configuration: + - some things are possible from Python but not JSON + - CapturedCallable parsing can be split more cleaning depending on pydantic context +- We can introduce this argument as mandatory for JSON and it's non breaking. Or we could just decide based on +whether dashboard is instance of Dashboard whether to parse and JSON or as Python. This would be breaking for current +examples but we could change to this behaviour in breaking release. + +Additional notes MS: +- using SkipJsonSchema seems to be a good idea, as it cleans up the schema, which should not care about arbitrary Python extensions. +- using type = "" is not a good idea because it will hide the default from the JSON schema, which may (although not tested) +seriously confuse LLMs +- if __pydantic_init_subclass__ is not causing any trouble, then this might be the better solution. EDIT: it is causeing trouble!! + + +""" + +from __future__ import annotations + +import random +import re +import uuid +from collections.abc import Mapping +from types import SimpleNamespace +from typing import TYPE_CHECKING, Annotated, Any, Literal, Self, Union + +from nutree.typed_tree import TypedTree +from pydantic import ( + AfterValidator, + BaseModel, + ConfigDict, + Discriminator, + Field, + ModelWrapValidatorHandler, + PrivateAttr, + Tag, + ValidatorFunctionWrapHandler, + conlist, + field_validator, + model_validator, +) +from pydantic.json_schema import SkipJsonSchema +from pydantic_core.core_schema import ValidationInfo + +rd = random.Random(0) + +# Forward reference setup - creates circular dependency with actions.py +# Using TYPE_CHECKING avoids circular import at runtime, but causes forward ref issue +if TYPE_CHECKING: + from vizro.models._fake_vizro.actions import ExportDataAction +# Don't define ExportDataAction at runtime - this will cause PydanticUndefinedAnnotation + +from vizro.models._base import _validate_with_tree_context + + +# Written by ChatGPT +def camel_to_snake(name): + # Add underscores before uppercase letters, then lowercase everything + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def make_discriminated_union(*args): + # Build discriminated union out of types in args. Tags are just the snake case version of the class names. + # Tag "custom_component" must validate as Any to keep its custom class. + builtin_tags = [camel_to_snake(T.__name__) for T in args] + types = [Annotated[T, Tag(builtin_tag)] for T, builtin_tag in zip(args, builtin_tags)] + types.append(SkipJsonSchema[Annotated[Any, Tag("custom_component")]]) + + # print(types) + + # With a proper type field established, we could go back to a normal discriminator on this field + # but this has the other consequence of needing a work-around for the stack mechanism of the model manager + # see former implementation here: https://github.com/McK-Internal/vizro-internal/issues/2273 + def discriminator(model): + if isinstance(model, dict): + # YAML configuration where no custom type possible + if len(builtin_tags) == 1: + # Fake discriminated union where there's only one option. + # Coerce to that model (could raise error if type specified and doesn't match if we wanted to, doesn't + # really matter) + return builtin_tags[0] + else: + # Real discriminated union case need a type to be specified + # If it's not specified then return None which wil raise a pydantic discriminated union error + return model.get("type", None) + elif isinstance(model, VizroBaseModel): + # Find tag of supplied model. + return model.type + else: + raise ValueError("something") + + return Annotated[Union[tuple(types)], Field(discriminator=Discriminator(discriminator))] # noqa: UP007 + + +class VizroBaseModel(BaseModel): + # Default type for base model. Subclasses should override with their specific Literal type. + # Custom components should set type: str = "custom_component" or type: Literal["custom_component"] = "custom_component" + type: Literal["vizro_base_model"] = Field(default="vizro_base_model") + model_config = ConfigDict( + extra="forbid", # Good for spotting user typos and being strict. + validate_assignment=True, # Run validators when a field is assigned after model instantiation. + revalidate_instances="always", + ) + + id: Annotated[ + str, + Field( + default_factory=lambda: str(uuid.UUID(int=rd.getrandbits(128))), + description="ID to identify model. Must be unique throughout the whole dashboard. " + "When no ID is chosen, ID will be automatically generated.", + validate_default=True, + ), + ] + _tree: TypedTree | None = PrivateAttr(None) # initialised in model_after + + @staticmethod + def _ensure_model_in_tree(model: VizroBaseModel, context: dict[str, Any]) -> VizroBaseModel: + """Revalidate a VizroBaseModel instance if it hasn't been added to the tree yet.""" + has_tree = hasattr(model, "_tree") + tree_is_none = getattr(model, "_tree", None) is None + if not has_tree or tree_is_none: + # Revalidate with build_tree context to ensure tree node is created + # NOTE: This can cause UniqueConstraintError if the model was already added to the tree + # during normal validation but _tree wasn't set yet (see Container in Tabs issue) + return model.__class__.model_validate(model, context=context) + return model + + @staticmethod + def _ensure_models_in_tree(validated_stuff: Any, context: dict[str, Any]) -> Any: + """Recursively ensure all VizroBaseModel instances in a structure are added to the tree.""" + if isinstance(validated_stuff, VizroBaseModel): + return VizroBaseModel._ensure_model_in_tree(validated_stuff, context) + elif isinstance(validated_stuff, list): + return [VizroBaseModel._ensure_models_in_tree(item, context) for item in validated_stuff] + elif isinstance(validated_stuff, Mapping) and not isinstance(validated_stuff, str): + # Note: str is a Mapping in Python, so we exclude it + return type(validated_stuff)( + {key: VizroBaseModel._ensure_models_in_tree(value, context) for key, value in validated_stuff.items()} + ) + return validated_stuff + + @field_validator("*", mode="wrap") + @classmethod + def build_tree_field_wrap( + cls, + value: Any, + handler: ValidatorFunctionWrapHandler, + info: ValidationInfo, + ) -> Any: + if info.context is not None and "build_tree" in info.context: + #### Field stack #### + if "id_stack" not in info.context: + info.context["id_stack"] = [] + if "field_stack" not in info.context: + info.context["field_stack"] = [] + if info.field_name == "id": + info.context["id_stack"].append(value) + else: + info.context["id_stack"].append(info.data.get("id", "no id")) + info.context["field_stack"].append(info.field_name) + #### Level and indentation #### + # indent = info.context["level"] * " " * 4 + # info.context["level"] += 1 + + #### Validation #### + validated_stuff = handler(value) + + if info.context is not None and "build_tree" in info.context: + #### Ensure VizroBaseModel instances are added to tree #### + # This handles the case where custom components match 'Any' in discriminated unions + # and might not go through full revalidation + # Note: field_stack and id_stack are still in place here (before the pop below) + # so build_tree_model_wrap will have the correct context + validated_stuff = VizroBaseModel._ensure_models_in_tree(validated_stuff, info.context) + + #### Field stack cleanup #### + # Pop after revalidation so the stacks are available during revalidation + info.context["id_stack"].pop() + info.context["field_stack"].pop() + + #### Level and indentation #### + # info.context["level"] -= 1 + # indent = info.context["level"] * " " * 4 + return validated_stuff + + @model_validator(mode="wrap") + @classmethod + def build_tree_model_wrap(cls, data: Any, handler: ModelWrapValidatorHandler[Self], info: ValidationInfo) -> Self: + #### ID #### + # Check Page ID case! + # Even change way we set it in Page (path logic etc - ideally separate PR) + # Leave page setting ID logic for now. + model_id = "UNKNOWN_ID" + if isinstance(data, dict): + if "id" not in data or data["id"] is None: + model_id = str(uuid.uuid4()) + data["id"] = model_id + # print(f" Setting id to {model_id}") + elif isinstance(data["id"], str): + model_id = data["id"] + # print(f" Using id {model_id}") + elif hasattr(data, "id"): + model_id = data.id + # print(f" Using id {model_id}") + else: + print("GRANDE PROBLEMA!!!") + + if info.context is not None and "build_tree" in info.context: + #### Level and indentation #### + if "level" not in info.context: + info.context["level"] = 0 + indent = info.context["level"] * " " * 4 + info.context["level"] += 1 + + #### Tree #### + print( + f"{indent}{cls.__name__} Before validation: {info.context['field_stack'] if 'field_stack' in info.context else 'no field stack'} with model id {model_id}" + ) + + if "parent_model" in info.context: + # print("IF PARENT MODEL") + info.context["tree"] = info.context["parent_model"]._tree + tree = info.context["tree"] + tree[info.context["parent_model"].id].add( + SimpleNamespace(id=model_id), kind=info.context["field_stack"][-1] + ) + # info.context["tree"].print() + elif "tree" not in info.context: + # print("NO PARENT MODEL, NO TREE") + tree = TypedTree("Root", calc_data_id=lambda tree, data: data.id) + tree.add(SimpleNamespace(id=model_id), kind="dashboard") # make this more general + info.context["tree"] = tree + # info.context["tree"].print() + else: + # print("NO PARENT MODEL, TREE") + tree = info.context["tree"] + # in words: add a node as children to the parent (so id one higher up), but add as kind + # the field in which you currently are + # ID STACK and FIELD STACK are different "levels" of the tree. + tree[info.context["id_stack"][-1]].add( + SimpleNamespace(id=model_id), kind=info.context["field_stack"][-1] + ) + # print("-" * 50) + + #### Validation #### + validated_stuff = handler(data) + + if info.context is not None and "build_tree" in info.context: + #### Replace placeholder nodes and propagate tree to all models #### + info.context["tree"][validated_stuff.id].set_data(validated_stuff) + validated_stuff._tree = info.context["tree"] + + #### Level and indentation #### + info.context["level"] -= 1 + indent = info.context["level"] * " " * 4 + print(f"{indent}{cls.__name__} After validation: {info.context['field_stack']}") + elif hasattr(data, "_tree") and data._tree is not None: + #### Revalidation case: model already has a tree (e.g., during assignment) #### + # Inherit the tree from the original instance + validated_stuff._tree = data._tree + # Update the tree node to point to the NEW validated instance + validated_stuff._tree[validated_stuff.id].set_data(validated_stuff) + print(f"--> Revalidation: Updated tree node for {validated_stuff.id} <--") + + return validated_stuff + + @classmethod + # AM NOTE: name TBC. + def from_pre_build(cls, data, parent_model, field_name): + # Note this always adds new models to the tree. It's not currently possible to replace or remove a node. + # It should work with any parent_model, but ideally we should only use it to make children of the calling + # model, so that parent_model=self in the call (where self isn't the created model instance, it's the calling + # model). + # Since we have revalidate_instances = "always", calling model_validate on a single model will also execute + # the validators on children models. + # MS: On children models we have the problem that they would not be correctly added + return cls.model_validate( + data, + context={ + "build_tree": True, + "parent_model": parent_model, + "field_stack": [field_name], + "id_stack": [parent_model.id], + }, + ) + + +def make_actions_chain(self): + for action in self.actions: + action.action = action.action + " (from make_actions_chain)" + action._parent_model = self + return self + + +class Action(VizroBaseModel): + type: Literal["action"] = "action" + action: str + # This field uses ExportDataAction - creates the forward reference issue + # Using string forward reference to trigger PydanticUndefinedAnnotation + function: str | ExportDataAction = "default" + + _parent_model: VizroBaseModel = PrivateAttr() + + +class Graph(VizroBaseModel): + type: Literal["graph"] = "graph" + figure: str + actions: list[Action] | None + + @model_validator(mode="after") + def _make_actions_chain(self): + return make_actions_chain(self) + + +class Card(VizroBaseModel): + type: Literal["card"] = "card" + text: str + + +class SubComponent(VizroBaseModel): + type: Literal["sub_component"] = "sub_component" + y: str = "subcomponent" + + +class Component(VizroBaseModel): + type: Literal["component"] = "component" + x: str | list[SubComponent] + + +class Container(VizroBaseModel): + type: Literal["container"] = "container" + title: str = "" + components: list[make_discriminated_union(Graph, Card, Component)] + + +def validate_tab_has_title(tab: Container) -> Container: + if not tab.title: + raise ValueError("`Container` must have a `title` explicitly set when used inside `Tabs`.") + return tab + + +class Tabs(VizroBaseModel): + type: Literal["tabs"] = "tabs" + tabs: conlist(Annotated[Container, AfterValidator(validate_tab_has_title)], min_length=1) # type: ignore[valid-type] + + +class Page(VizroBaseModel): + type: Literal["page"] = "page" + title: str + # Example of field where there's multiple options so it's already a real discriminated union. + components: list[make_discriminated_union(Graph, Card, Component, Tabs)] + + def pre_build(self): + print(f"Updating page {self.type}") + if isinstance(self.components[0], Component) and self.components[0].x == "c1": + self.components = [ + _validate_with_tree_context( + Component(x="new c1!!!"), # , SubComponent(y="another new c3") + parent_model=self, + field_name="components", + ) + ] + + +class Dashboard(VizroBaseModel): + type: Literal["dashboard"] = "dashboard" + # Example of field where there's really only one option that's built-in but we need to make it a discriminated union. + # This will make automated API docstrings much worse but we can explain it somewhere... + pages: list[make_discriminated_union(Page)] + + +if __name__ == "__main__": + """ +TODOs Maxi: +- test all combinations of yaml/python instantiations - DONE +- build in MM, see if pydantic_init_subclass is causing any problems - DONE +- check for model copy, do we loose private attributes still? Does it matter? - DONE +- check for json schema, does it look as nice as before? - DONE +- serialization/deserialization - DONE +- NEW: circular deps issue (see below) - DONE +- custom components do not end up in the tree - DONE +- containers in tabs seem to cause problems - need to investigate - DONE (was due to private attributes being lost, this + is not implemented in fake vizro) + +NOT FULLY RESOLVED +- what if we want to add normal component to other fields? (happens a lot!) - just use normal add_type? +==> This may run into the usual revalidate_instances problems! We may need to good schema modification function +==> after all!! +- check if pre-build needs to overwrite/delete models +- check if we ever need to add sub models in pre-build, so far it only works for single model +- how much to we need to care about idempotency of validation? Is there a difference between pre and post +pre-build and/or pre and post tree building? +- we should also check if users can call validation with build_tree context? + + +------------------------------------------------------------------------------------------------------------------------ +TODOs after moving the real Vizro: +------------------------------------------------------------------------------------------------------------------------ +- trial the model manager with the real Vizro and tree - DONE (dev example now works, notebooks not yet) +- try NOT referring to model_manager in a few places (probably take one of each case in the MM summary I once wrote) +- sort out page validation +- then bring over tests, and fix existing unit tests +- fix Jupyter notebooks properly (ie the callback_context issue), and also check if we can run multiple dashboards in +parallel? + + +Circular dependency issue: +------------------- +Circular dependency: models.py ↔ actions.py +- models.py needs ExportDataAction for type annotation +- actions.py needs VizroBaseModel from models.py to inherit + +The Problem: +- When Action class is defined, __pydantic_init_subclass__ runs +- It calls model_rebuild(force=True) +- Pydantic tries to evaluate Union[str, "ExportDataAction"] +- ExportDataAction not in namespace → PydanticUndefinedAnnotation + +Resolution attempts: +- many unstable solutions suggested by Claude, did not try them all +- since we rebuild the models in __init__.py, we can just import ExportDataAction after the models have been rebuilt +- HOWEVER, this still creates incomplete schemas (some $defs in models do not update), as Vizro is highly hierarchical, +so MRO matters, and the order of resolving models needs to be carefully considered (essentially the old add_type problem) +See also: https://docs.pydantic.dev/latest/internals/resolving_annotations/#limitations-and-backwards-compatibility-concerns + +==> Using __pydantic_init_subclass__ is not a viable solution if we want the schema of every model to be correct. +==> SOLUTION: remove __pydantic_init_subclass__ and use the new (old) system where we explicitly define types. + +""" diff --git a/vizro-core/src/vizro/models/_navigation/_navigation_utils.py b/vizro-core/src/vizro/models/_navigation/_navigation_utils.py index 4bf4afcb6..012750aec 100644 --- a/vizro-core/src/vizro/models/_navigation/_navigation_utils.py +++ b/vizro-core/src/vizro/models/_navigation/_navigation_utils.py @@ -56,8 +56,6 @@ def _resolve_list_of_page_references( return unknown_pages, validated_list -# TODO[MS]: This will need to move to pre-build in next PR - hopefully there is no problems -# introduced with handling things this way def _validate_pages(pages: NavPagesType) -> NavPagesType: """Reusable validator to check if provided Page titles exist as registered pages.""" from vizro.models import Page diff --git a/vizro-core/src/vizro/models/_navigation/accordion.py b/vizro-core/src/vizro/models/_navigation/accordion.py index 7f75998d7..e216a9f56 100644 --- a/vizro-core/src/vizro/models/_navigation/accordion.py +++ b/vizro-core/src/vizro/models/_navigation/accordion.py @@ -3,13 +3,12 @@ import dash_bootstrap_components as dbc from dash import get_relative_path -from pydantic import AfterValidator, BeforeValidator, Field +from pydantic import BeforeValidator, Field from vizro._constants import ACCORDION_DEFAULT_TITLE from vizro.managers._model_manager import model_manager from vizro.models import VizroBaseModel from vizro.models._models_utils import _log_call -from vizro.models._navigation._navigation_utils import _validate_pages from vizro.models.types import ModelID @@ -37,11 +36,15 @@ class Accordion(VizroBaseModel): str, list[ModelID], # TODO[MS]:this is the type after validation, but the type before validation is NavPagesType ], - AfterValidator(_validate_pages), BeforeValidator(coerce_pages_type), Field(default={}, description="Mapping from name of a pages group to a list of page IDs/titles."), ] + # @_log_call + # def pre_build(self): + # TODO[MS]: we may need to validate pages here? + # _validate_pages(self.pages) + @_log_call def build(self, *, active_page_id=None): # Note build does not return _NavBuildType but just a single html.Div with id="nav-panel". diff --git a/vizro-core/src/vizro/models/_navigation/nav_bar.py b/vizro-core/src/vizro/models/_navigation/nav_bar.py index c987318b5..d33ecebc4 100644 --- a/vizro-core/src/vizro/models/_navigation/nav_bar.py +++ b/vizro-core/src/vizro/models/_navigation/nav_bar.py @@ -4,14 +4,15 @@ import dash_bootstrap_components as dbc from dash import html -from pydantic import AfterValidator, BeforeValidator, Field +from pydantic import BeforeValidator, Field from vizro.managers import model_manager from vizro.models import VizroBaseModel +from vizro.models._base import _validate_with_tree_context from vizro.models._models_utils import _log_call -from vizro.models._navigation._navigation_utils import _NavBuildType, _validate_pages +from vizro.models._navigation._navigation_utils import _NavBuildType from vizro.models._navigation.nav_link import NavLink -from vizro.models.types import ModelID +from vizro.models.types import ModelID, make_discriminated_union def coerce_pages_type(pages: Any) -> Any: @@ -37,24 +38,29 @@ class NavBar(VizroBaseModel): type: Literal["nav_bar"] = "nav_bar" pages: Annotated[ dict[str, list[ModelID]], - AfterValidator(_validate_pages), BeforeValidator(coerce_pages_type), Field(default={}, description="Mapping from name of a pages group to a list of page IDs/titles."), ] - items: list[NavLink] = [] + items: list[make_discriminated_union(NavLink)] = [] @_log_call def pre_build(self): from vizro.models import Page + # TODO[MS]: we may need to validate pages here? + # self.pages = _validate_pages(self.pages) self.items = self.items or [ - NavLink( - # If the group title is a page ID (as is the case if you do `NavBar(pages=["page_1_id", "page_2_id"])`, - # then we prefer to have the title rather than id of that page be used - label=cast(Page, model_manager[group_title]).title - if group_title in [page.id for page in model_manager._get_models(model_type=Page)] - else group_title, - pages=pages, + _validate_with_tree_context( + NavLink( + # If the group title is a page ID (as is the case if you do `NavBar(pages=["page_1_id", "page_2_id"])`, + # then we prefer to have the title rather than id of that page be used + label=cast(Page, model_manager[group_title]).title + if group_title in [page.id for page in model_manager._get_models(model_type=Page)] + else group_title, + pages=pages, + ), + parent_model=self, + field_name="items", ) for group_title, pages in self.pages.items() ] diff --git a/vizro-core/src/vizro/models/_navigation/nav_link.py b/vizro-core/src/vizro/models/_navigation/nav_link.py index ba58e0a99..2762cc06e 100644 --- a/vizro-core/src/vizro/models/_navigation/nav_link.py +++ b/vizro-core/src/vizro/models/_navigation/nav_link.py @@ -1,7 +1,7 @@ from __future__ import annotations import itertools -from typing import Annotated, cast +from typing import Annotated, Literal, cast import dash_bootstrap_components as dbc from dash import get_relative_path, html @@ -9,6 +9,7 @@ from vizro.managers._model_manager import model_manager from vizro.models import VizroBaseModel +from vizro.models._base import _validate_with_tree_context from vizro.models._models_utils import _log_call, validate_icon from vizro.models._navigation._navigation_utils import _validate_pages from vizro.models._navigation.accordion import Accordion @@ -28,7 +29,11 @@ class NavLink(VizroBaseModel): """ - pages: Annotated[NavPagesType, AfterValidator(_validate_pages), Field(default=[])] + type: Literal["nav_link"] = "nav_link" + pages: Annotated[ + NavPagesType, + Field(default=[]), + ] label: str = Field(description="Text description of the icon for use in tooltip.") icon: Annotated[ str, @@ -39,9 +44,13 @@ class NavLink(VizroBaseModel): @_log_call def pre_build(self): - from vizro.models._navigation.accordion import Accordion - - self._nav_selector = Accordion(pages=self.pages) # type: ignore[arg-type] + # TODO[MS]: Check validate pages properly + self.pages = _validate_pages(self.pages) + self._nav_selector = _validate_with_tree_context( + Accordion(pages=self.pages), + parent_model=self, + field_name="_nav_selector", + ) @_log_call def build(self, *, active_page_id=None): diff --git a/vizro-core/src/vizro/models/_navigation/navigation.py b/vizro-core/src/vizro/models/_navigation/navigation.py index 84b9b29f3..c7089469d 100644 --- a/vizro-core/src/vizro/models/_navigation/navigation.py +++ b/vizro-core/src/vizro/models/_navigation/navigation.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import Annotated, cast +from typing import Annotated, Literal, cast import dash_bootstrap_components as dbc from dash import html -from pydantic import AfterValidator, Field +from pydantic import Field from vizro.models import VizroBaseModel +from vizro.models._base import _validate_with_tree_context from vizro.models._models_utils import _log_call from vizro.models._navigation._navigation_utils import _NavBuildType, _validate_pages from vizro.models._navigation.accordion import Accordion @@ -26,17 +27,22 @@ class Navigation(VizroBaseModel): """ - pages: Annotated[NavPagesType, AfterValidator(_validate_pages), Field(default=[])] + type: Literal["navigation"] = "navigation" + pages: Annotated[ + NavPagesType, + Field(default=[]), + ] nav_selector: NavSelectorType | None = None @_log_call def pre_build(self): - # Since models instantiated in pre_build do not themselves have pre_build called on them, we call it manually - # here. Note that not all nav_selectors have pre_build (Accordion does not). - self.nav_selector = self.nav_selector or Accordion() - self.nav_selector.pages = self.nav_selector.pages or self.pages - if hasattr(self.nav_selector, "pre_build"): - self.nav_selector.pre_build() + # TODO[MS]: Check validate pages properly + self.pages = _validate_pages(self.pages) + self.nav_selector = self.nav_selector or _validate_with_tree_context( + Accordion(pages=self.pages), + parent_model=self, + field_name="nav_selector", + ) @_log_call def build(self, *, active_page_id=None) -> _NavBuildType: diff --git a/vizro-core/src/vizro/models/_page.py b/vizro-core/src/vizro/models/_page.py index ec10f0316..53fe20c12 100644 --- a/vizro-core/src/vizro/models/_page.py +++ b/vizro-core/src/vizro/models/_page.py @@ -2,7 +2,7 @@ from collections.abc import Iterable from itertools import chain -from typing import Annotated, cast +from typing import Annotated, Literal, cast import dash_mantine_components as dmc from dash import ClientsideFunction, Input, Output, State, clientside_callback, dcc, html @@ -20,6 +20,7 @@ from vizro.managers import model_manager from vizro.managers._model_manager import FIGURE_MODELS from vizro.models import Filter, Parameter, Tooltip, VizroBaseModel +from vizro.models._base import _validate_with_tree_context from vizro.models._grid import set_layout from vizro.models._models_utils import ( _all_hidden, @@ -29,7 +30,7 @@ make_actions_chain, warn_description_without_title, ) -from vizro.models.types import ActionsType, _IdProperty +from vizro.models.types import ActionsType, _IdProperty, make_discriminated_union from ._action._action import _BaseAction from ._tooltip import coerce_str_to_tooltip @@ -70,13 +71,14 @@ class Page(VizroBaseModel): """ # TODO[mypy], see: https://github.com/pydantic/pydantic/issues/156 for components field + type: Literal["page"] = "page" components: conlist(Annotated[ComponentType, BeforeValidator(check_captured_callable_model)], min_length=1) # type: ignore[valid-type] title: str = Field(description="Title of the `Page`") layout: Annotated[LayoutType | None, AfterValidator(set_layout), Field(default=None, validate_default=True)] # TODO: ideally description would have json_schema_input_type=str | Tooltip attached to the BeforeValidator, # but this requires pydantic >= 2.9. description: Annotated[ - Tooltip | None, + make_discriminated_union(Tooltip) | None, BeforeValidator(coerce_str_to_tooltip), AfterValidator(warn_description_without_title), Field( @@ -105,10 +107,10 @@ def validate_path(self): new_path = clean_path(self.title, "-_") # Check for duplicate path - will move to pre_build in next PR - for page in cast(Iterable[Page], model_manager._get_models(Page)): - # Need to check for id equality to avoid checking the same page against itself - if not self.id == page.id and new_path == page.path: - raise ValueError(f"Path {new_path} cannot be used by more than one page.") + # for page in cast(Iterable[Page], model_manager._get_models(Page)): + # # Need to check for id equality to avoid checking the same page against itself + # if not self.id == page.id and new_path == page.path: + # raise ValueError(f"Path {new_path} cannot be used by more than one page.") # We should do self.path = new_path but this leads to a recursion error. The below is a workaround # until the pydantic bug is fixed. See https://github.com/pydantic/pydantic/issues/6597. @@ -150,7 +152,16 @@ def pre_build(self): targets = figure_targets + filter_targets if targets: - self.actions = [_on_page_load(id=f"{ON_PAGE_LOAD_ACTION_PREFIX}_{self.id}", targets=targets)] + self.actions = [ + _validate_with_tree_context( + _on_page_load( + id=f"{ON_PAGE_LOAD_ACTION_PREFIX}_{self.id}", + targets=targets, + ), + parent_model=self, + field_name="actions", + ) + ] # Convert generator to list as it's going to be iterated multiple times. # Use "root_model=self" as controls can be defined inside a "Container.controls" under the "Page.components". diff --git a/vizro-core/src/vizro/models/_tooltip.py b/vizro-core/src/vizro/models/_tooltip.py index 11fe72a5d..757ef50c7 100644 --- a/vizro-core/src/vizro/models/_tooltip.py +++ b/vizro-core/src/vizro/models/_tooltip.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Annotated, Any +from typing import Annotated, Any, Literal import dash_bootstrap_components as dbc from dash import dcc, html @@ -53,6 +53,7 @@ class Tooltip(VizroBaseModel): ``` """ + type: Literal["tooltip"] = "tooltip" text: str = Field( description="Markdown string for text shown when hovering over the icon. Should adhere to the CommonMark Spec." ) diff --git a/vizro-core/src/vizro/models/types.py b/vizro-core/src/vizro/models/types.py index c37a377df..1b8476f61 100644 --- a/vizro-core/src/vizro/models/types.py +++ b/vizro-core/src/vizro/models/types.py @@ -5,12 +5,13 @@ # ruff: noqa: F821 import functools import inspect +import sys import warnings from collections import OrderedDict from collections.abc import Callable from contextlib import contextmanager from datetime import date -from typing import Annotated, Any, Literal, Protocol, TypeAlias, cast, runtime_checkable +from typing import Annotated, Any, Literal, Protocol, TypeAlias, Union, cast, runtime_checkable import plotly.io as pio import pydantic_core as cs @@ -30,6 +31,27 @@ from vizro.charts._charts_utils import _DashboardReadyFigure +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing import TypeAlias + + +def camel_to_snake(name: str) -> str: + """Convert camelCase to snake_case. + + Args: + name: CamelCase string to convert + + Returns: + snake_case string + """ + import re + + # Add underscores before uppercase letters, then lowercase everything + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + def _get_layout_discriminator(layout: Any) -> str | None: """Helper function for callable discriminator used for LayoutType.""" @@ -77,6 +99,70 @@ def _get_action_discriminator(action: Any) -> str | None: return getattr(action, "type", None) +def _get_type_discriminator(model: Any, builtin_tags: list[str] | None = None) -> str | None: + """Shared discriminator function for extracting the 'type' field from models or dicts. + + Used by both make_discriminated_union (for single-type and multi-type unions) and + multi-type unions in types.py (ComponentType, SelectorType, ControlType, etc.). + + Args: + model: Model instance, dict, or other object to extract type from + builtin_tags: Optional list of builtin tags. If provided and length is 1, + enables single-type coercion (returns the tag even if type is not specified). + + Returns: + The type string or None if not found + """ + if isinstance(model, dict): + # YAML configuration where no custom type possible + if builtin_tags is not None and len(builtin_tags) == 1: + # Fake discriminated union where there's only one option. + # Coerce to that model (could raise error if type specified and doesn't match if we wanted to, doesn't + # really matter) + return builtin_tags[0] + else: + # Real discriminated union case need a type to be specified + # If it's not specified then return None which will raise a pydantic discriminated union error + return model.get("type", None) + elif hasattr(model, "type") and hasattr(model, "id"): + # Find tag of supplied model (check for VizroBaseModel instance). + return model.type + else: + raise ValueError("something") + + +# Helper function wrapper for multi-type unions that don't need single-type coercion +def _get_multi_type_discriminator(model: Any) -> str | None: + """Wrapper for _get_type_discriminator for multi-type unions. + + Used for SelectorType, _FormComponentType, ControlType, ComponentType, and NavSelectorType. + Passes None for builtin_tags to disable single-type coercion. + """ + return _get_type_discriminator(model, builtin_tags=None) + + +def make_discriminated_union(*args): + """Build discriminated union out of types in args. + + Tags are just the snake case version of the class names. + Tag "custom_component" must validate as Any to keep its custom class. + + Args: + *args: Types to include in the discriminated union + + Returns: + Annotated union with discriminator field + """ + builtin_tags = [camel_to_snake(T.__name__) for T in args] + types = [Annotated[T, Tag(builtin_tag)] for T, builtin_tag in zip(args, builtin_tags)] + types.append(SkipJsonSchema[Annotated[Any, Tag("custom_component")]]) + + def discriminator(model): + return _get_type_discriminator(model, builtin_tags) + + return Annotated[Union[tuple(types)], Field(discriminator=Discriminator(discriminator))] # noqa: UP007 + + def _clean_module_string(module_string: str) -> str: from vizro.models._models_utils import REPLACEMENT_STRINGS @@ -282,7 +368,7 @@ def _function(self): @classmethod def _validate_captured_callable( cls, - captured_callable_config: Union[dict[str, Any], _SupportsCapturedCallable, CapturedCallable], + captured_callable_config: dict[str, Any] | _SupportsCapturedCallable | CapturedCallable, json_schema_extra: _JsonSchemaExtraType, allow_undefined_captured_callable: list[str], ): @@ -316,10 +402,10 @@ def _core_validation(value: Any): @classmethod def _parse_json( cls, - captured_callable_config: Union[_SupportsCapturedCallable, CapturedCallable, dict[str, Any]], + captured_callable_config: _SupportsCapturedCallable | CapturedCallable | dict[str, Any], json_schema_extra: _JsonSchemaExtraType, allow_undefined_captured_callable: list[str], - ) -> Union[CapturedCallable, _SupportsCapturedCallable]: + ) -> CapturedCallable | _SupportsCapturedCallable: """Parses captured_callable_config specification from JSON/YAML. If captured_callable_config is already _SupportCapturedCallable or CapturedCallable then it just passes through @@ -368,7 +454,7 @@ def _parse_json( @classmethod def _extract_from_attribute( - cls, captured_callable: Union[_SupportsCapturedCallable, CapturedCallable] + cls, captured_callable: _SupportsCapturedCallable | CapturedCallable ) -> CapturedCallable: """Extracts CapturedCallable from _SupportCapturedCallable (e.g. _DashboardReadyFigure). @@ -408,7 +494,7 @@ def _check_type( @staticmethod def _format_args( - args_for_repr: Optional[Union[list[Any], tuple[Any, ...]]] = None, arguments: Optional[dict[str, Any]] = None + args_for_repr: list[Any] | tuple[Any, ...] | None = None, arguments: dict[str, Any] | None = None ) -> str: """Format arguments for string representation.""" return ", ".join( @@ -642,29 +728,59 @@ class _OptionsDictType(TypedDict): # All the below types rely on models and so must use ForwardRef (i.e. "Checklist" rather than actual Checklist class). SelectorType = Annotated[ - "Checklist | DatePicker | Dropdown | RadioItems | RangeSlider | Slider | Switch", - Field(discriminator="type", description="Selectors to be used inside a control."), + Annotated["Checklist", Tag("checklist")] + | Annotated["DatePicker", Tag("date_picker")] + | Annotated["Dropdown", Tag("dropdown")] + | Annotated["RadioItems", Tag("radio_items")] + | Annotated["RangeSlider", Tag("range_slider")] + | Annotated["Slider", Tag("slider")] + | Annotated["Switch", Tag("switch")] + | SkipJsonSchema[Annotated[Any, Tag("custom_component")]], + Field( + discriminator=Discriminator(_get_multi_type_discriminator), + description="Selectors to be used inside a control.", + ), ] """Discriminated union. Type of selector to be used inside a control: [`Checklist`][vizro.models.Checklist], [`DatePicker`][vizro.models.DatePicker], [`Dropdown`][vizro.models.Dropdown], [`RadioItems`][vizro.models.RadioItems], [`RangeSlider`][vizro.models.RangeSlider], [`Slider`][vizro.models.Slider] or [`Switch`][vizro.models.Switch].""" _FormComponentType = Annotated[ - "SelectorType | Button | UserInput", - Field(discriminator="type", description="Components that can be used to receive user input within a form."), + SelectorType + | Annotated["Button", Tag("button")] + | Annotated["UserInput", Tag("user_input")] + | SkipJsonSchema[Annotated[Any, Tag("custom_component")]], + Field( + discriminator=Discriminator(_get_multi_type_discriminator), + description="Components that can be used to receive user input within a form.", + ), ] ControlType = Annotated[ - "Filter | Parameter", - Field(discriminator="type", description="Control that affects components on the page."), + Annotated["Filter", Tag("filter")] + | Annotated["Parameter", Tag("parameter")] + | SkipJsonSchema[Annotated[Any, Tag("custom_component")]], + Field( + discriminator=Discriminator(_get_multi_type_discriminator), + description="Control that affects components on the page.", + ), ] """Discriminated union. Type of control that affects components on the page: [`Filter`][vizro.models.Filter] or [`Parameter`][vizro.models.Parameter].""" ComponentType = Annotated[ - "AgGrid | Button | Card | Container | Figure | Graph | Text | Table | Tabs", + Annotated["AgGrid", Tag("ag_grid")] + | Annotated["Button", Tag("button")] + | Annotated["Card", Tag("card")] + | Annotated["Container", Tag("container")] + | Annotated["Figure", Tag("figure")] + | Annotated["Graph", Tag("graph")] + | Annotated["Text", Tag("text")] + | Annotated["Table", Tag("table")] + | Annotated["Tabs", Tag("tabs")] + | SkipJsonSchema[Annotated[Any, Tag("custom_component")]], Field( - discriminator="type", + discriminator=Discriminator(_get_multi_type_discriminator), description="Component that makes up part of the layout on the page.", ), ] @@ -680,7 +796,13 @@ class _OptionsDictType(TypedDict): "List of page IDs or a mapping from name of a group to a list of page IDs (for hierarchical sub-navigation)." NavSelectorType = Annotated[ - "Accordion | NavBar", Field(discriminator="type", description="Component for rendering navigation.") + Annotated["Accordion", Tag("accordion")] + | Annotated["NavBar", Tag("nav_bar")] + | SkipJsonSchema[Annotated[Any, Tag("custom_component")]], + Field( + discriminator=Discriminator(_get_multi_type_discriminator), + description="Component for rendering navigation.", + ), ] """Discriminated union. Type of component for rendering navigation: [`Accordion`][vizro.models.Accordion] or [`NavBar`][vizro.models.NavBar].""" diff --git a/vizro-core/tests/unit/vizro/models/_fake_vizro/models/test_models.py b/vizro-core/tests/unit/vizro/models/_fake_vizro/models/test_models.py new file mode 100644 index 000000000..f66fe9b6a --- /dev/null +++ b/vizro-core/tests/unit/vizro/models/_fake_vizro/models/test_models.py @@ -0,0 +1,552 @@ +"""Tests for fake Vizro models to verify custom component handling.""" + +from typing import Literal + +import pytest +from pydantic import ValidationError + +from vizro.models._fake_vizro.models import ( + Action, + Card, + Component, + Container, + Dashboard, + Graph, + Page, + Tabs, + VizroBaseModel, +) + +# Custom component classes for testing + + +######## Page ############ +class CustomPage(Page): + """Custom page that accepts int for title instead of str.""" + + type: Literal["custom_component"] = "custom_component" + title: int + + +class CustomPageBase(VizroBaseModel): + """Custom page component directly subclassing VizroBaseModel.""" + + type: Literal["custom_component"] = "custom_component" + title: int + components: list[Graph | Card] + + +######### Graph ############ +class CustomGraph(Graph): + """Custom graph that accepts int for figure instead of str.""" + + type: Literal["custom_component"] = "custom_component" + figure: int + + +class CustomGraph2(Graph): + """Custom graph that accepts int for figure instead of str.""" + + type: Literal["custom_component"] = "custom_component" + figure: int + + +class CustomGraphNoType(Graph): + """Custom graph that accepts int for figure instead of str.""" + + figure: int + + +class CustomGraphNoTypeUpwardsCompatible(Graph): + """Custom graph inheriting from Graph without type, but upwards compatible with Graph.""" + + @classmethod + def dummy_method(cls): + return "dummy" + + +class CustomGraphBase(VizroBaseModel): + """Custom graph component directly subclassing VizroBaseModel.""" + + type: Literal["custom_component"] = "custom_component" + figure: int + + +class CustomGraphBaseNoType(VizroBaseModel): + """Custom graph component directly subclassing VizroBaseModel but without type.""" + + figure: int + + +# Tests +class TestFakeVizroNormalInstantiation: + """Test normal (non-custom) component instantiation.""" + + def test_python_instantiation(self): + """Test normal Python instantiation with model objects.""" + graph = Graph(figure="test_figure", actions=[Action(action="a")]) + page = Page(title="Test Page", components=[graph]) + dashboard = Dashboard(pages=[page]) + dashboard = Dashboard.model_validate(dashboard) + + assert type(dashboard.pages[0]) is Page + assert type(dashboard.pages[0].components[0]) is Graph + + def test_yaml_dict_instantiation(self): + """Test normal YAML/dict instantiation.""" + graph_dict = {"figure": "test_figure", "type": "graph", "actions": [{"action": "a"}]} + page_dict = {"title": "Test Page", "components": [graph_dict]} + dashboard_dict = {"pages": [page_dict]} + dashboard = Dashboard.model_validate(dashboard_dict) + + assert type(dashboard.pages[0]) is Page + assert type(dashboard.pages[0].components[0]) is Graph + + +class TestFakeVizroCustomComponentSubclassSpecificModel: + """Test custom components that subclass specific models (Page, Graph).""" + + def test_custom_page_in_normal_field(self): + """Test custom component (subclass of Page) in normal field (pages).""" + custom_page = CustomPage(title=456, components=[Graph(figure="test_figure", actions=[Action(action="a")])]) + dashboard = Dashboard(pages=[custom_page]) + dashboard = Dashboard.model_validate(dashboard) + + assert type(dashboard.pages[0]) is CustomPage + + def test_custom_graph_in_discriminated_union_field(self): + """Test custom component (subclass of Graph) in discriminated union field (components).""" + custom_graph = CustomGraph(figure=123, actions=[Action(action="a")]) + page = Page(title="Test Page", components=[custom_graph]) + dashboard = Dashboard(pages=[page]) + dashboard = Dashboard.model_validate(dashboard) + + assert type(dashboard.pages[0].components[0]) is CustomGraph + + def test_custom_graph_no_type_in_discriminated_union_field(self): + """Test custom component (subclass of Graph) without explicit type field fails validation. + + When no type is specified, Pydantic uses the class name of the parent class, + which will cause pydantic to validated against that parent class rather than the custom component. + """ + custom_graph_no_type = CustomGraphNoType(figure=999, actions=[Action(action="a")]) + with pytest.raises(ValidationError, match="Input should be a valid string"): + Page(title="Test Page", components=[custom_graph_no_type]) + + def test_custom_graph_no_type_upwards_compatible_in_discriminated_union_field(self): + """Test custom component (subclass of Graph) without explicit type field but upwards compatible with Graph. + + When no type is specified, Pydantic uses the type of the parent class as type, + which will cause pydantic to validate against that parent class rather than the custom component. + If the custom component is upwards compatible with the parent class, it will be validated against + the parent class, losing any extra functionality defined in the custom component. + """ + custom_graph_no_type_upwards_compatible = CustomGraphNoTypeUpwardsCompatible( + figure="string", actions=[Action(action="a")] + ) + page = Page(title="Test Page", components=[custom_graph_no_type_upwards_compatible]) + assert type(page.components[0]) is Graph + + def test_multiple_custom_components_in_discriminated_union_field(self): + """Test multiple custom components in discriminated union field (components).""" + custom_graph_1 = CustomGraph(figure=123, actions=[Action(action="a")]) + custom_graph_2 = CustomGraph2(figure=456, actions=[Action(action="a")]) + page = Page(title="Test Page", components=[custom_graph_1, custom_graph_2]) + dashboard = Dashboard(pages=[page]) + dashboard = Dashboard.model_validate(dashboard) + + assert type(dashboard.pages[0].components[0]) is CustomGraph + assert type(dashboard.pages[0].components[1]) is CustomGraph2 + + +class TestFakeVizroCustomComponentSubclassVizroBaseModel: + """Test custom components that directly subclass VizroBaseModel.""" + + def test_custom_page_base_in_normal_field(self): + """Test custom component (subclass of VizroBaseModel) in normal field (pages).""" + custom_page_base = CustomPageBase( + title=789, components=[Graph(figure="test_figure", actions=[Action(action="a")])] + ) + dashboard = Dashboard(pages=[custom_page_base]) + dashboard = Dashboard.model_validate(dashboard) + + assert type(dashboard.pages[0]) is CustomPageBase + + def test_custom_graph_base_in_discriminated_union_field(self): + """Test custom component (subclass of VizroBaseModel) in discriminated union field (components).""" + custom_graph_base = CustomGraphBase(figure=999) + page = Page(title="Test Page", components=[custom_graph_base]) + dashboard = Dashboard(pages=[page]) + dashboard = Dashboard.model_validate(dashboard) + + assert type(dashboard.pages[0].components[0]) is CustomGraphBase + + def test_custom_graph_base_no_type_in_discriminated_union_field(self): + """Test custom component (subclass of VizroBaseModel) without explicit type field fails validation. + + Without a type field, Pydantic defaults to using the class name as the discriminator, + which won't match registered union types. + """ + custom_graph_base_no_type = CustomGraphBaseNoType(figure=999) + with pytest.raises(ValidationError, match="Input tag 'vizro_base_model' found"): + Page(title="Test Page", components=[custom_graph_base_no_type]) + + +class TestFakeVizroYAMLWithCustomComponent: + """Test that YAML/dict instantiation with custom component types is not supported.""" + + def test_yaml_with_custom_component_should_fail(self): + """Test that YAML with custom component type fails validation. + + Custom components must be instantiated as Python objects, not via YAML/dict, + because their types are not registered in the discriminated union. + """ + custom_graph_dict = {"figure": 123, "type": "custom_graph"} + page_dict = {"title": "Test Page", "components": [custom_graph_dict]} + dashboard_dict = {"pages": [page_dict]} + + with pytest.raises(ValidationError): + Dashboard.model_validate(dashboard_dict) + + +class TestFakeVizroValidationErrors: + """Test that invalid configurations raise validation errors.""" + + def test_wrong_model_in_pages_field_python(self): + """Test that using Graph instead of Page raises validation error in Python.""" + graph = Graph(figure="a", actions=[Action(action="a")]) + + with pytest.raises(ValidationError): + Dashboard(pages=[graph]) + + def test_wrong_model_in_pages_field_yaml(self): + """Test that using Graph instead of Page raises validation error in YAML/dict.""" + graph_dict = {"figure": "a", "type": "graph"} # Graph is not a Page + dashboard_dict = {"pages": [graph_dict]} + + with pytest.raises(ValidationError): + Dashboard.model_validate(dashboard_dict) + + +class TestFakeVizroLiteralType: + """Test that type field validation works correctly with Literal types.""" + + def test_literal_type_builtin_model(self): + """Test that built-in models enforce their Literal type value.""" + graph = Graph(figure="a", actions=[Action(action="a")]) + assert graph.type == "graph" + with pytest.raises(ValidationError): + Graph(figure="a", type="custom_component") + + def test_literal_type_custom_model(self): + """Test custom components with Literal type value enforce their Literal type value.""" + custom_graph = CustomGraph(figure=3, type="custom_component", actions=[Action(action="a")]) + assert custom_graph.type == "custom_component" + + with pytest.raises(ValidationError): + CustomGraph(figure=3, type="graph") + + def test_literal_type_custom_model_no_type(self): + """When no type is specified, Pydantic uses the type of the parent class as type.""" + custom_graph_no_type = CustomGraphNoType(figure=3, actions=[Action(action="a")]) + assert custom_graph_no_type.type == "graph" + with pytest.raises(ValidationError): + CustomGraphNoType(figure=3, type="custom_component") + + def test_literal_type_custom_model_base(self): + """Test custom model bases with Literal type value enforce their Literal type value.""" + custom_graph_base = CustomGraphBase(figure=3) + assert custom_graph_base.type == "custom_component" + with pytest.raises(ValidationError): + CustomGraphBase(figure=3, type="graph") + + def test_literal_type_custom_model_base_no_type(self): + """When no type is specified, Pydantic uses the type of the parent class as type. + + This model cannot be used anywhere in Vizro though, because it has no valid or custom_component type. + """ + custom_graph_base_no_type = CustomGraphBaseNoType(figure=3) + assert custom_graph_base_no_type.type == "vizro_base_model" + with pytest.raises(ValidationError): + CustomGraphBaseNoType(figure=3, type="graph") + + with pytest.raises(ValidationError): + CustomGraphBaseNoType(figure=3, type="custom_component") + + +@pytest.fixture +def dashboard_with_graph_and_action(): + """Fixture for a dashboard with a graph and an action.""" + return Dashboard(pages=[Page(title="Test Page", components=[Graph(figure="a", actions=[Action(action="a")])])]) + + +@pytest.fixture +def dashboard_with_graph_and_action_revalidated(dashboard_with_graph_and_action): + """Dashboard revalidated without tree context.""" + return Dashboard.model_validate(dashboard_with_graph_and_action) + + +@pytest.fixture +def dashboard_with_graph_and_action_revalidated_with_tree(dashboard_with_graph_and_action): + """Dashboard revalidated with build_tree context enabled.""" + return Dashboard.model_validate(dashboard_with_graph_and_action, context={"build_tree": True}) + + +@pytest.fixture +def dashboard_with_graph_and_action_revalidated_with_tree_revalidated( + dashboard_with_graph_and_action_revalidated_with_tree, +): + """Dashboard with tree revalidated again (tests tree persistence).""" + return Dashboard.model_validate(dashboard_with_graph_and_action_revalidated_with_tree) + + +@pytest.fixture +def dashboard(request): + """Parametrized fixture that returns the requested dashboard fixture.""" + return request.getfixturevalue(request.param) + + +@pytest.fixture +def dashboard_with_tree(request): + """Fixture for a dashboard with tree.""" + return request.getfixturevalue(request.param) + + +# Fixture name constants for parametrization - used to test different states of dashboard validation +DASHBOARDS_WITHOUT_TREE = [ + "dashboard_with_graph_and_action", + "dashboard_with_graph_and_action_revalidated", +] + +DASHBOARDS_WITH_TREE = [ + "dashboard_with_graph_and_action_revalidated_with_tree", + "dashboard_with_graph_and_action_revalidated_with_tree_revalidated", +] + + +class TestFakeVizroDashboardTreeCreation: + """Test tree creation using the validator approach.""" + + @pytest.mark.parametrize("dashboard", DASHBOARDS_WITHOUT_TREE, indirect=True) + def test_tree_creation_not_triggered(self, dashboard): + """Test tree creation is not triggered.""" + assert dashboard._tree is None + + # TODO: For real Vizro we need to check all three cases (normal, list, mapping) + def test_custom_component_in_tree(self): + """Test that custom components are added to the tree.""" + + class CustomCard(VizroBaseModel): + type: Literal["custom_component"] = "custom_component" + title: str + + dashboard = Dashboard.model_validate( + Dashboard( + pages=[Page(title="test", components=[Graph(figure="fig", actions=[]), CustomCard(title="custom")])] + ), + context={"build_tree": True}, + ) + assert dashboard.pages[0].components[1].id in {node.data.id for node in dashboard._tree} + + @pytest.mark.parametrize("dashboard_with_tree", DASHBOARDS_WITH_TREE, indirect=True) + def test_tree_creation_triggered(self, dashboard_with_tree): + """Test tree creation is triggered when build_tree context is provided. + + This test checks for a number of facts about the tree. + """ + assert dashboard_with_tree._tree is not None + + # 0. Check tree reference is shared across all models in the hierarchy + assert dashboard_with_tree._tree is dashboard_with_tree.pages[0]._tree + + # 1. Check root exists and has correct structure + assert dashboard_with_tree._tree.name == "Root" + + # 2. Check node count matches expected hierarchy + # Dashboard -> Page -> Graph -> Action = 4 nodes total + assert len(dashboard_with_tree._tree) == 4 + + # 3. Check tree depth + assert dashboard_with_tree._tree.calc_height() == 4 # Root -> Dashboard -> Page -> Graph -> Action + + # 4. Verify node kinds are correct (field names) + kind_checks = [ + (dashboard_with_tree.pages[0].id, "pages"), + (dashboard_with_tree.pages[0].components[0].id, "components"), + (dashboard_with_tree.pages[0].components[0].actions[0].id, "actions"), + ] + for model_id, expected_kind in kind_checks: + assert dashboard_with_tree._tree[model_id].kind == expected_kind + + # 5. Check all nodes have valid data and correspond to real model objects + models_to_check = [ + dashboard_with_tree, + dashboard_with_tree.pages[0], + dashboard_with_tree.pages[0].components[0], + dashboard_with_tree.pages[0].components[0].actions[0], + ] + for model in models_to_check: + node = dashboard_with_tree._tree[model.id] + assert isinstance(node.data, VizroBaseModel) + assert hasattr(node.data, "id") + assert node.data is model + + @pytest.mark.parametrize("dashboard", DASHBOARDS_WITHOUT_TREE + DASHBOARDS_WITH_TREE, indirect=True) + def test_private_attribute_parent_model(self, dashboard): + """Test private attribute _parent_model.""" + assert dashboard.pages[0].components[0].actions[0]._parent_model is dashboard.pages[0].components[0] + + +@pytest.fixture +def dashboard_with_component_for_pre_build(): + """Fixture for a dashboard with a component for pre-build.""" + return Dashboard(pages=[Page(title="Test Page", components=[Component(x="c1")])]) + + +class TestFakeVizroPreBuildTreeAddition: + """Test that tree nodes are properly updated when models are modified during pre-build.""" + + def test_pre_build_tree_addition(self, dashboard_with_component_for_pre_build): + """Test that pre-build modifications update tree node references correctly.""" + dashboard = Dashboard.model_validate(dashboard_with_component_for_pre_build, context={"build_tree": True}) + + for page in dashboard.pages: + page.pre_build() + + # Check component was changed by pre-build + new_component = dashboard.pages[0].components[0] + assert isinstance(new_component, Component) + assert new_component.x == "new c1!!!" + + # Check all tree nodes have valid data and correspond to real model objects + models_to_check = [ + dashboard, + dashboard.pages[0], + new_component, + ] + for model in models_to_check: + node = dashboard._tree[model.id] + assert isinstance(node.data, VizroBaseModel) + assert hasattr(node.data, "id") + assert node.data is model + + +class TestFakeVizroJSONSchema: + """Test that the JSON schema looks as expected.""" + + def test_json_schema_card(self): + """Test that the JSON schema for Card looks as expected, particular for the type field.""" + schema = Card.model_json_schema() + expected_schema = { + "additionalProperties": False, + "properties": { + "type": {"const": "card", "default": "card", "title": "Type", "type": "string"}, + "id": { + "description": ( + "ID to identify model. Must be unique throughout the whole dashboard. " + "When no ID is chosen, ID will be automatically generated." + ), + "title": "Id", + "type": "string", + }, + "text": {"title": "Text", "type": "string"}, + }, + "required": ["text"], + "title": "Card", + "type": "object", + } + assert schema == expected_schema + + def test_json_schema_page(self): + """Test that the JSON schema for Page looks as expected, particular for the type field and the discriminated union field.""" + schema = Page.model_json_schema() + # Remove all $defs before comparison + if "$defs" in schema: + schema.pop("$defs") + expected_schema = { + "additionalProperties": False, + "properties": { + "type": {"const": "page", "default": "page", "title": "Type", "type": "string"}, + "id": { + "description": ( + "ID to identify model. Must be unique throughout the whole dashboard. " + "When no ID is chosen, ID will be automatically generated." + ), + "title": "Id", + "type": "string", + }, + "title": {"title": "Title", "type": "string"}, + "components": { + "items": { + "oneOf": [ + {"$ref": "#/$defs/Graph"}, + {"$ref": "#/$defs/Card"}, + {"$ref": "#/$defs/Component"}, + ], + }, + "title": "Components", + "type": "array", + }, + }, + "required": ["title", "components"], + "title": "Page", + "type": "object", + } + assert schema == expected_schema + + +class TestFakeVizroSerialization: + """Test that the serialization works as expected. + + Note that we currently cannot get rid of the ID in serialization I think. + """ + + def test_serialization_graph(self): + """Test that the serialization for Graph works as expected.""" + graph = Graph(id="graph-id", figure="a", actions=[Action(id="action-id", action="a")]) + assert graph.model_dump() == { + "id": "graph-id", + "type": "graph", + "figure": "a", + "actions": [ + { + "id": "action-id", + "type": "action", + "action": "a (from make_actions_chain)", + "function": "default", + } + ], + } + + def test_serialization_without_id(self): + """Test that the serialization for Graph works as expected without id.""" + graph = Graph(id="graph-id", figure="a", actions=[Action(id="action-id", action="a")]) + assert graph.model_dump(exclude_unset=True) == { + "id": "graph-id", + "figure": "a", + "actions": [{"id": "action-id", "action": "a (from make_actions_chain)"}], + } + + +class TestFakeVizroContainerInTabs: + """Test that Container in Tabs causes UniqueConstraintError.""" + + def test_container_in_tabs_unique_constraint_error(self): + """Test that using a Container in Tabs causes UniqueConstraintError when building tree.""" + from nutree.common import UniqueConstraintError + + container = Container( + title="Tab I", + components=[ + Graph(figure="test_figure_1", actions=[Action(action="a")]), + Graph(figure="test_figure_2", actions=[Action(action="b")]), + ], + ) + + tabs = Tabs(tabs=[container]) + page = Page(title="Tabs", components=[tabs]) + dashboard = Dashboard(pages=[page]) + + # This should raise UniqueConstraintError + with pytest.raises(UniqueConstraintError, match="Node.data already exists in parent"): + Dashboard.model_validate(dashboard, context={"build_tree": True})