-
Notifications
You must be signed in to change notification settings - Fork 94
Add get_model_lineage_dev CLI tool
#420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
637e5ca
bc06c32
adf92cc
6d5db43
adeba32
c041c1a
d818a18
74f2309
e832496
9bcb837
16591c5
d49a815
5a48ff7
3de7eee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| from __future__ import annotations | ||
| from typing import Literal | ||
|
|
||
| from pydantic import BaseModel, Field | ||
| from typing import Any | ||
|
|
||
|
|
||
| class Descendant(BaseModel): | ||
| model_id: str | ||
| children: list[Descendant] = Field(default_factory=list) | ||
|
|
||
|
|
||
| class Ancestor(BaseModel): | ||
| model_id: str | ||
| parents: list[Ancestor] = Field(default_factory=list) | ||
|
|
||
|
|
||
| class ModelLineage(BaseModel): | ||
| model_id: str | ||
| parents: list[Ancestor] = Field(default_factory=list) | ||
| children: list[Descendant] = Field(default_factory=list) | ||
|
|
||
| @classmethod | ||
| def from_manifest( | ||
| cls, | ||
| manifest: dict[str, Any], | ||
| model_id: str, | ||
| direction: Literal["parents", "children", "both"] = "both", | ||
| exclude_prefixes: tuple[str, ...] = ("test.", "unit_test."), | ||
| *, | ||
| recursive: bool = False, | ||
| ) -> ModelLineage: | ||
| """ | ||
| Build a ModelLineage instance from a dbt manifest mapping. | ||
| - manifest: dict containing at least 'parent_map' and/or 'child_map' | ||
| - model_id: the model id to start from | ||
| - recursive: whether to traverse recursively | ||
| - direction: one of 'parents', 'children', or 'both' | ||
| - exclude_prefixes: tuple of prefixes to exclude from descendants, defaults to ("test.", "unit_test.") | ||
| Descendants only. Give () to include all. | ||
| The returned ModelLineage contains lists of Ancestor and/or Descendant | ||
| objects. For compatibility with the previous implementation, recursive | ||
| traversal returns a flat list of Ancestor/Descendant nodes (no nested | ||
| parents/children relationships are constructed). | ||
| """ | ||
| parent_map = manifest.get("parent_map", {}) | ||
| child_map = manifest.get("child_map", {}) | ||
|
|
||
| parents: list[Ancestor] = [] | ||
| children: list[Descendant] = [] | ||
| model_id = get_uid_from_name(manifest, model_id) | ||
|
|
||
| if direction in ("both", "parents"): | ||
| if not recursive: | ||
| # direct parents only | ||
| for pid in parent_map.get(model_id, []): | ||
| parents.append(Ancestor.model_validate({"model_id": pid})) | ||
| else: | ||
| # Build nested ancestor trees. We prevent cycles using path tracking. | ||
| def _build_ancestor(node_id: str, path: set[str]) -> Ancestor: | ||
| if node_id in path: | ||
| # cycle detected, return node without parents | ||
VDFaller marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return Ancestor.model_validate({"model_id": node_id}) | ||
| new_path = set(path) | ||
| new_path.add(node_id) | ||
| parents = [ | ||
| _build_ancestor(pid, new_path) | ||
| for pid in parent_map.get(node_id, []) | ||
| ] | ||
| return Ancestor.model_validate( | ||
| {"model_id": node_id, "parents": parents} | ||
| ) | ||
|
|
||
| for pid in parent_map.get(model_id, []): | ||
| parents.append(_build_ancestor(pid, {model_id})) | ||
|
|
||
| if direction in ("both", "children"): | ||
| if not recursive: | ||
| children = [ | ||
| Descendant.model_validate({"model_id": cid}) | ||
| for cid in child_map.get(model_id, []) | ||
| ] | ||
| else: | ||
| # Build nested descendant trees. Prevent cycles using path tracking. | ||
| def _build_descendant(node_id: str, path: set[str]) -> Descendant: | ||
|
||
| if node_id in path: | ||
| return Descendant.model_validate({"model_id": node_id}) | ||
| new_path = set(path) | ||
| new_path.add(node_id) | ||
| # exclude children with specified prefixes | ||
| new_children = [ | ||
| cid | ||
| for cid in child_map.get(node_id, []) | ||
| if not cid.startswith(exclude_prefixes) | ||
| ] | ||
|
|
||
| children = [ | ||
| _build_descendant(cid, new_path) for cid in new_children | ||
| ] | ||
| return Descendant.model_validate( | ||
| {"model_id": node_id, "children": children}, | ||
| context={"exclude_prefixes": exclude_prefixes}, | ||
| ) | ||
|
|
||
| for cid in [ | ||
| cid | ||
| for cid in child_map.get(model_id, []) | ||
| if not cid.startswith(exclude_prefixes) | ||
| ]: | ||
| children.append(_build_descendant(cid, {model_id})) | ||
| return cls( | ||
| model_id=model_id, | ||
| parents=parents, | ||
| children=children, | ||
| ) | ||
|
|
||
|
|
||
| def get_uid_from_name(manifest: dict[str, Any], model_id: str) -> str: | ||
| """ | ||
| Given a dbt manifest mapping and a model name, return the unique_id | ||
| corresponding to that model name, or None if not found. | ||
| """ | ||
| # using the parent and child map so it include sources/exposures | ||
| if model_id in manifest["child_map"] or model_id in manifest["parent_map"]: | ||
| return model_id | ||
| # fallback: look through eveything for the identifier | ||
DevonFulcher marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for uid, node in manifest.get("nodes", {}).items(): | ||
| if node.get("name") == model_id: | ||
| return uid | ||
| for uid, source in manifest.get("sources", {}).items(): | ||
| if source.get("identifier") == model_id: | ||
| return uid | ||
| for uid, exposure in manifest.get("exposures", {}).items(): | ||
| if exposure.get("name") == model_id: | ||
| return uid | ||
| raise ValueError(f"Model name '{model_id}' not found in manifest.") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,14 @@ | ||
| import os | ||
| import subprocess | ||
| from collections.abc import Iterable, Sequence | ||
| from typing import Any, Literal | ||
|
|
||
| from mcp.server.fastmcp import FastMCP | ||
| from pydantic import Field | ||
|
|
||
| from dbt_mcp.config.config import DbtCliConfig | ||
| from dbt_mcp.dbt_cli.binary_type import get_color_disable_flag | ||
| from dbt_mcp.dbt_cli.models.lineage_types import ModelLineage | ||
| from dbt_mcp.prompts.prompts import get_prompt | ||
| from dbt_mcp.tools.definitions import ToolDefinition | ||
| from dbt_mcp.tools.register import register_tools | ||
|
|
@@ -176,6 +178,34 @@ def show( | |
| args.extend(["--output", "json"]) | ||
| return _run_dbt_command(args) | ||
|
|
||
| def _get_manifest() -> dict[str, Any]: | ||
VDFaller marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Helper function to load the dbt manifest.json file.""" | ||
| import json | ||
VDFaller marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| _run_dbt_command(["parse"]) # Ensure manifest is generated | ||
| cwd_path = config.project_dir if os.path.isabs(config.project_dir) else None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Non-blocking: when the server starts up, we should make all paths absolute if they aren't already. |
||
| manifest_path = os.path.join(cwd_path or ".", "target", "manifest.json") | ||
| with open(manifest_path) as f: | ||
| manifest = json.load(f) | ||
| return manifest | ||
|
|
||
| def get_model_lineage_dev( | ||
| model_id: str, | ||
| direction: Literal["parents", "children", "both"] = "both", | ||
| exclude_prefixes: tuple[str, ...] = ("test.", "unit_test."), | ||
| *, | ||
| recursive: bool, | ||
| ) -> dict[str, Any]: | ||
|
Comment on lines
+193
to
+199
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Non-blocking: I think this should have essentially the same function signature as the other lineage tool: https://github.com/dbt-labs/dbt-mcp/pull/461/files#diff-6d91f0721d8dcde8199de504338811a7063757ec13f32eca508bfbc8b663a54bR390-R396
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll make a separate PR once they're both in to align them. Cool? |
||
| manifest = _get_manifest() | ||
| model_lineage = ModelLineage.from_manifest( | ||
| manifest, | ||
| model_id, | ||
| direction=direction, | ||
| exclude_prefixes=exclude_prefixes, | ||
| recursive=recursive, | ||
| ) | ||
| return model_lineage.model_dump() | ||
|
|
||
| return [ | ||
| ToolDefinition( | ||
| fn=build, | ||
|
|
@@ -258,6 +288,17 @@ def show( | |
| idempotent_hint=True, | ||
| ), | ||
| ), | ||
| ToolDefinition( | ||
| name="get_model_lineage_dev", | ||
| fn=get_model_lineage_dev, | ||
| description=get_prompt("dbt_cli/get_model_lineage_dev"), | ||
| annotations=create_tool_annotations( | ||
| title="Get Model Lineage (Dev)", | ||
| read_only_hint=True, | ||
| destructive_hint=False, | ||
| idempotent_hint=True, | ||
| ), | ||
| ), | ||
| ] | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| get_model_lineage_dev | ||
|
|
||
| <instructions> | ||
| Retrieves the model lineage of a specific dbt model, it allows for upstream, downstream, or both. These are the models that depend on the specified model. | ||
|
|
||
| You can provide either a model_name or a uniqueId, if known, to identify the model. Using uniqueId is more precise and guarantees a unique match, which is especially useful when models might have the same name in different projects. | ||
| This specifically ONLY pulls from the local development manifest. If you want production lineage, use `get_model_children` or `get_model_parents` instead. | ||
| </instructions> | ||
|
|
||
| <parameters> | ||
| model_id: str => Either the uniqueId or the `identifier` of the dbt model to retrieve lineage for. | ||
| direction: Literal["parents", "children", "both"] = "both" => The direction of lineage to retrieve. "parents" for upstream models, "children" for downstream models, and "both" for both directions. | ||
| exclude_prefixes: tuple[str, ...] = ("test.", "unit_test."), => A tuple of prefixes to exclude from the lineage results. Assets with identifiers starting with any of these prefixes will be ignored. | ||
| recursive: bool = False => Whether to retrieve lineage recursively. If set to True, it will fetch all levels of lineage in the specified direction(s). | ||
| </parameters> | ||
|
|
||
| <examples> | ||
| 1. Getting children for a model by name: | ||
| get_model_lineage_dev(model_id="customer_orders", direction="children") | ||
|
|
||
| 2. Getting parents for a model by uniqueId (more precise): | ||
| get_model_lineage_dev(model_id="model.my_project.customer_orders", direction="parents") | ||
|
|
||
| 3. Getting both upstream and downstream lineage recursively and including tests: | ||
| get_model_lineage_dev(model_id="model.my_project.customer_orders", direction="both", exclude_prefixes=(), recursive=True) | ||
| </examples> |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| import pytest | ||
|
|
||
| from dbt_mcp.dbt_cli.models.lineage_types import ModelLineage | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def sample_manifest(): | ||
| yield { | ||
| "child_map": { | ||
| "model.a": ["model.b", "model.c"], | ||
| "model.b": ["model.d", "test.not_included"], | ||
| "model.c": [], | ||
| "model.d": [], | ||
| "source.1": ["model.a"], | ||
| }, | ||
| "parent_map": { | ||
| "model.b": ["model.a"], | ||
| "model.c": ["model.a"], | ||
| "model.d": ["model.b"], | ||
| "model.a": ["source.1"], | ||
| "source.1": [], | ||
| }, | ||
| "nodes": { | ||
| "model.a": {"name": "a"}, | ||
| "model.b": {"name": "b"}, | ||
| "model.c": {"name": "c"}, | ||
| "model.d": {"name": "d"}, | ||
| }, | ||
| "sources": { | ||
| "source.1": {"identifier": "1"}, | ||
| }, | ||
| "exposures": { | ||
| "exposure.1": {"name": "1"}, | ||
| }, | ||
| } | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "model_id", | ||
| [ | ||
| pytest.param("model.a", id="using_full_model_id"), | ||
| pytest.param("a", id="using_model_name_only"), | ||
| ], | ||
| ) | ||
| def test_model_lineage_a__from_manifest(sample_manifest, model_id): | ||
| manifest = sample_manifest | ||
| lineage = ModelLineage.from_manifest( | ||
| manifest, model_id, direction="both", recursive=True | ||
| ) | ||
| assert lineage.model_id == "model.a" | ||
| assert lineage.parents[0].model_id == "source.1", ( | ||
| "Expected source.1 as parent to model.a" | ||
| ) | ||
| assert len(lineage.children) == 2, "Expected 2 children for model.a" | ||
| model_b = lineage.children[0] | ||
| assert model_b.model_id == "model.b", "Expected model.b as first child of model.a" | ||
| assert len(model_b.children) == 1, ( | ||
| "Expect test.not_included to be excluded from children of model.b" | ||
| ) | ||
| assert model_b.children[0].model_id == "model.d", ( | ||
| "Expected model.d as child of model.b" | ||
| ) | ||
|
|
||
|
|
||
| def test_model_lineage_b__from_manifest(sample_manifest): | ||
| manifest = sample_manifest | ||
| lineage_b = ModelLineage.from_manifest( | ||
| manifest, "model.b", direction="parents", recursive=True | ||
| ) | ||
| assert lineage_b.model_id == "model.b" | ||
| assert len(lineage_b.parents) == 1, "Expected 1 parent for model.b" | ||
|
|
||
| assert len(lineage_b.children) == 0, ( | ||
| "Expected no children when only fetching parents" | ||
| ) | ||
|
|
||
|
|
||
| def test_model_lineage__from_manifest_with_tests(sample_manifest): | ||
| manifest = sample_manifest | ||
|
|
||
| lineage = ModelLineage.from_manifest( | ||
| manifest, "model.a", direction="children", recursive=True, exclude_prefixes=() | ||
| ) | ||
| assert len(lineage.children) == 2, "Expected 2 children for model.a" | ||
| model_b = lineage.children[0] | ||
| assert model_b.model_id == "model.b", "Expected model.b as first child of model.a" | ||
| assert len(model_b.children) == 2, "Expected 2 children for model.b including tests" | ||
| assert lineage.children[0].children[1].model_id == "test.not_included" | ||
| assert len(lineage.parents) == 0, "Expected no parents when only fetching children" |
Uh oh!
There was an error while loading. Please reload this page.