Skip to content

Commit 4fc3ba6

Browse files
authored
Feat(dbt_cli): Add --select and --exclude options (#5200)
1 parent 4209672 commit 4fc3ba6

File tree

7 files changed

+346
-22
lines changed

7 files changed

+346
-22
lines changed

sqlmesh_dbt/cli.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@ def _get_dbt_operations(ctx: click.Context) -> DbtOperations:
1212
return ctx.obj
1313

1414

15+
select_option = click.option(
16+
"-s",
17+
"-m",
18+
"--select",
19+
"--models",
20+
"--model",
21+
multiple=True,
22+
help="Specify the nodes to include.",
23+
)
24+
exclude_option = click.option("--exclude", multiple=True, help="Specify the nodes to exclude.")
25+
26+
1527
@click.group(invoke_without_command=True)
1628
@click.option("--profile", help="Which existing profile to load. Overrides output.profile")
1729
@click.option("-t", "--target", help="Which target to load for the given profile")
@@ -38,23 +50,26 @@ def dbt(
3850

3951

4052
@dbt.command()
41-
@click.option("-s", "-m", "--select", "--models", "--model", help="Specify the nodes to include.")
53+
@select_option
54+
@exclude_option
4255
@click.option(
4356
"-f",
4457
"--full-refresh",
4558
help="If specified, dbt will drop incremental models and fully-recalculate the incremental table from the model definition.",
4659
)
4760
@click.pass_context
48-
def run(ctx: click.Context, select: t.Optional[str], full_refresh: bool) -> None:
61+
def run(ctx: click.Context, **kwargs: t.Any) -> None:
4962
"""Compile SQL and execute against the current target database."""
50-
_get_dbt_operations(ctx).run(select=select, full_refresh=full_refresh)
63+
_get_dbt_operations(ctx).run(**kwargs)
5164

5265

5366
@dbt.command(name="list")
67+
@select_option
68+
@exclude_option
5469
@click.pass_context
55-
def list_(ctx: click.Context) -> None:
70+
def list_(ctx: click.Context, **kwargs: t.Any) -> None:
5671
"""List the resources in your project"""
57-
_get_dbt_operations(ctx).list_()
72+
_get_dbt_operations(ctx).list_(**kwargs)
5873

5974

6075
@dbt.command(name="ls", hidden=True) # hidden alias for list

sqlmesh_dbt/console.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,27 @@
1+
import typing as t
12
from sqlmesh.core.console import TerminalConsole
3+
from sqlmesh.core.model import Model
4+
from rich.tree import Tree
25

36

47
class DbtCliConsole(TerminalConsole):
5-
# TODO: build this out
6-
78
def print(self, msg: str) -> None:
89
return self._print(msg)
10+
11+
def list_models(
12+
self, models: t.List[Model], list_parents: bool = True, list_audits: bool = True
13+
) -> None:
14+
model_list = Tree("[bold]Models in project:[/bold]")
15+
16+
for model in models:
17+
model_tree = model_list.add(model.name)
18+
19+
if list_parents:
20+
for parent in model.depends_on:
21+
model_tree.add(f"depends_on: {parent}")
22+
23+
if list_audits:
24+
for audit_name in model.audit_definitions:
25+
model_tree.add(f"audit: {audit_name}")
26+
27+
self._print(model_list)

sqlmesh_dbt/operations.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,46 @@
22
import typing as t
33
from rich.progress import Progress
44
from pathlib import Path
5+
import logging
6+
from sqlmesh_dbt import selectors
57

68
if t.TYPE_CHECKING:
79
# important to gate these to be able to defer importing sqlmesh until we need to
810
from sqlmesh.core.context import Context
911
from sqlmesh.dbt.project import Project
1012
from sqlmesh_dbt.console import DbtCliConsole
13+
from sqlmesh.core.model import Model
14+
15+
logger = logging.getLogger(__name__)
1116

1217

1318
class DbtOperations:
1419
def __init__(self, sqlmesh_context: Context, dbt_project: Project):
1520
self.context = sqlmesh_context
1621
self.project = dbt_project
1722

18-
def list_(self) -> None:
19-
for _, model in self.context.models.items():
20-
self.console.print(model.name)
21-
22-
def run(self, select: t.Optional[str] = None, full_refresh: bool = False) -> None:
23-
# A dbt run both updates data and changes schemas and has no way of rolling back so more closely maps to a SQLMesh forward-only plan
24-
# TODO: if --full-refresh specified, mark incrementals as breaking instead of forward_only?
25-
26-
# TODO: we need to either convert DBT selector syntax to SQLMesh selector syntax
27-
# or make the model selection engine configurable
23+
def list_(
24+
self,
25+
select: t.Optional[t.List[str]] = None,
26+
exclude: t.Optional[t.List[str]] = None,
27+
) -> None:
28+
# dbt list prints:
29+
# - models
30+
# - "data tests" (audits) for those models
31+
# it also applies selectors which is useful for testing selectors
32+
selected_models = list(self._selected_models(select, exclude).values())
33+
self.console.list_models(selected_models)
34+
35+
def run(
36+
self,
37+
select: t.Optional[t.List[str]] = None,
38+
exclude: t.Optional[t.List[str]] = None,
39+
full_refresh: bool = False,
40+
) -> None:
2841
select_models = None
29-
if select:
30-
if "," in select:
31-
select_models = select.split(",")
32-
else:
33-
select_models = select.split(" ")
42+
43+
if sqlmesh_selector := selectors.to_sqlmesh(select or [], exclude or []):
44+
select_models = [sqlmesh_selector]
3445

3546
self.context.plan(
3647
select_models=select_models,
@@ -40,6 +51,21 @@ def run(self, select: t.Optional[str] = None, full_refresh: bool = False) -> Non
4051
auto_apply=True,
4152
)
4253

54+
def _selected_models(
55+
self, select: t.Optional[t.List[str]] = None, exclude: t.Optional[t.List[str]] = None
56+
) -> t.Dict[str, Model]:
57+
if sqlmesh_selector := selectors.to_sqlmesh(select or [], exclude or []):
58+
model_selector = self.context._new_selector()
59+
selected_models = {
60+
fqn: model
61+
for fqn, model in self.context.models.items()
62+
if fqn in model_selector.expand_model_selections([sqlmesh_selector])
63+
}
64+
else:
65+
selected_models = dict(self.context.models)
66+
67+
return selected_models
68+
4369
@property
4470
def console(self) -> DbtCliConsole:
4571
console = self.context.console

sqlmesh_dbt/selectors.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import typing as t
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
6+
7+
def to_sqlmesh(dbt_select: t.Collection[str], dbt_exclude: t.Collection[str]) -> t.Optional[str]:
8+
"""
9+
Given selectors defined in the format of the dbt cli --select and --exclude arguments, convert them into a selector expression that
10+
the SQLMesh selector engine can understand.
11+
12+
The main things being mapped are:
13+
- set union (" " between items within the same selector string OR multiple --select arguments) is mapped to " | "
14+
- set intersection ("," between items within the same selector string) is mapped to " & "
15+
- `--exclude`. The SQLMesh selector engine does not treat this as a separate parameter and rather treats exclusion as a normal selector
16+
that just happens to contain negation syntax, so we generate these by negating each expression and then intersecting the result
17+
with any --select expressions
18+
19+
Things that are *not* currently being mapped include:
20+
- selectors based on file paths
21+
- selectors based on partially qualified names like "model_a". The SQLMesh selector engine requires either:
22+
- wildcards, eg "*model_a*"
23+
- the full model name qualified with the schema, eg "staging.model_a"
24+
25+
Examples:
26+
--select "model_a"
27+
-> "model_a"
28+
--select "main.model_a"
29+
-> "main.model_a"
30+
--select "main.model_a" --select "main.model_b"
31+
-> "main.model_a | main.model_b"
32+
--select "main.model_a main.model_b"
33+
-> "main.model_a | main.model_b"
34+
--select "(main.model_a+ & ^main.model_b)"
35+
-> "(main.model_a+ & ^main.model_b)"
36+
--select "+main.model_a" --exclude "raw.src_data"
37+
-> "+main.model_a & ^(raw.src_data)"
38+
--select "+main.model_a" --select "main.*b+" --exclude "raw.src_data"
39+
-> "(+main.model_a | main.*b+) & ^(raw.src_data)"
40+
"""
41+
if not dbt_select and not dbt_exclude:
42+
return None
43+
44+
select_expr = " | ".join(_to_sqlmesh(expr) for expr in dbt_select)
45+
select_expr = _wrap(select_expr) if dbt_exclude and len(dbt_select) > 1 else select_expr
46+
47+
exclude_expr = " | ".join(_to_sqlmesh(expr, negate=True) for expr in dbt_exclude)
48+
exclude_expr = _wrap(exclude_expr) if dbt_select and len(dbt_exclude) > 1 else exclude_expr
49+
50+
main_expr = " & ".join([expr for expr in [select_expr, exclude_expr] if expr])
51+
52+
logger.debug(
53+
f"Expanded dbt select: {dbt_select}, exclude: {dbt_exclude} into SQLMesh: {main_expr}"
54+
)
55+
56+
return main_expr
57+
58+
59+
def _to_sqlmesh(selector_str: str, negate: bool = False) -> str:
60+
unions, intersections = _split_unions_and_intersections(selector_str)
61+
62+
if negate:
63+
unions = [_negate(u) for u in unions]
64+
intersections = [_negate(i) for i in intersections]
65+
66+
union_expr = " | ".join(unions)
67+
intersection_expr = " & ".join(intersections)
68+
69+
if len(unions) > 1 and intersections:
70+
union_expr = f"({union_expr})"
71+
72+
if len(intersections) > 1 and unions:
73+
intersection_expr = f"({intersection_expr})"
74+
75+
return " | ".join([expr for expr in [union_expr, intersection_expr] if expr])
76+
77+
78+
def _split_unions_and_intersections(selector_str: str) -> t.Tuple[t.List[str], t.List[str]]:
79+
# break space-separated items like: "my_first_model my_second_model" into a list of selectors to union
80+
# and comma-separated items like: "my_first_model,my_second_model" into a list of selectors to intersect
81+
# but, take into account brackets, eg "(my_first_model & my_second_model)" should not be split
82+
83+
def _split_by(input: str, delimiter: str) -> t.Iterator[str]:
84+
buf = ""
85+
depth = 0
86+
87+
for char in input:
88+
if char == delimiter and depth <= 0:
89+
# only split on a space if we are not within parenthesis
90+
yield buf
91+
buf = ""
92+
continue
93+
elif char == "(":
94+
depth += 1
95+
elif char == ")":
96+
depth -= 1
97+
98+
buf += char
99+
100+
if buf:
101+
yield buf
102+
103+
# first, break up based on spaces
104+
segments = list(_split_by(selector_str, " "))
105+
106+
# then, within each segment, identify the unions and intersections
107+
unions = []
108+
intersections = []
109+
110+
for segment in segments:
111+
maybe_intersections = list(_split_by(segment, ","))
112+
if len(maybe_intersections) > 1:
113+
intersections.extend(maybe_intersections)
114+
else:
115+
unions.append(segment)
116+
117+
return unions, intersections
118+
119+
120+
def _negate(expr: str) -> str:
121+
return f"^{_wrap(expr)}"
122+
123+
124+
def _wrap(expr: str) -> str:
125+
already_wrapped = expr.strip().startswith("(") and expr.strip().endswith(")")
126+
127+
if expr and not already_wrapped:
128+
return f"({expr})"
129+
130+
return expr

tests/dbt/cli/test_list.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,34 @@ def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
1515
assert "main.orders" in result.output
1616
assert "main.customers" in result.output
1717
assert "main.stg_payments" in result.output
18+
assert "main.raw_orders" in result.output
19+
20+
21+
def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
22+
result = invoke_cli(["list", "--select", "main.raw_customers+"])
23+
24+
assert result.exit_code == 0
25+
assert not result.exception
26+
27+
assert "main.orders" in result.output
28+
assert "main.customers" in result.output
29+
assert "main.stg_customers" in result.output
30+
assert "main.raw_customers" in result.output
31+
32+
assert "main.stg_payments" not in result.output
33+
assert "main.raw_orders" not in result.output
34+
35+
36+
def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
37+
result = invoke_cli(["list", "--select", "main.raw_customers+", "--exclude", "main.orders"])
38+
39+
assert result.exit_code == 0
40+
assert not result.exception
41+
42+
assert "main.customers" in result.output
43+
assert "main.stg_customers" in result.output
44+
assert "main.raw_customers" in result.output
45+
46+
assert "main.orders" not in result.output
47+
assert "main.stg_payments" not in result.output
48+
assert "main.raw_orders" not in result.output

tests/dbt/cli/test_run.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import pytest
33
from pathlib import Path
44
from click.testing import Result
5+
import time_machine
6+
from tests.cli.test_cli import FREEZE_TIME
57

68
pytestmark = pytest.mark.slow
79

@@ -13,3 +15,26 @@ def test_run(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
1315
assert not result.exception
1416

1517
assert "Model batches executed" in result.output
18+
19+
20+
def test_run_with_selectors(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
21+
with time_machine.travel(FREEZE_TIME):
22+
# do an initial run to create the objects
23+
# otherwise the selected subset may depend on something that hasnt been created
24+
result = invoke_cli(["run"])
25+
assert result.exit_code == 0
26+
assert "main.orders" in result.output
27+
28+
result = invoke_cli(["run", "--select", "main.raw_customers+", "--exclude", "main.orders"])
29+
30+
assert result.exit_code == 0
31+
assert not result.exception
32+
33+
assert "main.stg_customers" in result.output
34+
assert "main.stg_orders" in result.output
35+
assert "main.stg_payments" in result.output
36+
assert "main.customers" in result.output
37+
38+
assert "main.orders" not in result.output
39+
40+
assert "Model batches executed" in result.output

0 commit comments

Comments
 (0)