@property
def models_with_tests(self) -> t.Set[str]:
    """Names of the models that have at least one unit test in this context.

    Returns:
        The context's own set of model names (not a copy), which is filled
        from each loaded project's ``models_with_tests`` during ``load``.
    """
    return self._models_with_tests
@@ -2212,7 +2222,9 @@ def test( pd.set_option("display.max_columns", None) - test_meta = self.load_model_tests(tests=tests, patterns=match_patterns) + test_meta = self._filter_preloaded_tests( + test_meta=self._model_test_metadata, tests=tests, patterns=match_patterns + ) result = run_tests( model_test_metadata=test_meta, @@ -2773,6 +2785,35 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.") return self.engine_adapter + def _filter_preloaded_tests( + self, + test_meta: t.List[ModelTestMetadata], + tests: t.Optional[t.List[str]] = None, + patterns: t.Optional[t.List[str]] = None, + ) -> t.List[ModelTestMetadata]: + """Filter pre-loaded test metadata based on tests and patterns.""" + + if tests: + filtered_tests = [] + for test in tests: + if "::" in test: + filename, test_name = test.split("::", maxsplit=1) + filtered_tests.extend( + [ + t + for t in test_meta + if str(t.path) == filename and t.test_name == test_name + ] + ) + else: + filtered_tests.extend([t for t in test_meta if str(t.path) == test]) + test_meta = filtered_tests + + if patterns: + test_meta = filter_tests_by_patterns(test_meta, patterns) + + return test_meta + def _snapshots( self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None ) -> t.Dict[str, Snapshot]: diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py index f6bef4b4ef..5058f3a58a 100644 --- a/sqlmesh/core/linter/rules/builtin.py +++ b/sqlmesh/core/linter/rules/builtin.py @@ -129,6 +129,21 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]: return self.violation() +class NoMissingUnitTest(Rule): + """All models must have a unit test found in the test/ directory yaml files""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + # External models cannot have unit tests + if isinstance(model, ExternalModel): + return None + + if model.name 
@property
def model_name(self) -> str:
    """Name of the model this test targets, read from the test body.

    Returns:
        The value stored under the ``"model"`` key of ``body``, or an empty
        string when the key is absent. The lookup is tolerant so that
        collecting model names for every test does not fail with a bare
        ``KeyError`` on a test definition that omits the key.
    """
    # NOTE(review): a body without "model" is presumably rejected with a
    # clearer error when the test is actually built/run — confirm.
    return self.body.get("model", "")
test_no_missing_external_models_with_existing_file_not_ending_in_newline( ) fix_path = sushi_path / "external_models.yaml" assert edit.path == fix_path + + +def test_no_missing_unit_tests(tmp_path, copy_to_temp_path): + """ + Tests that the NoMissingUnitTest linter rule correctly identifies models + without corresponding unit tests in the tests/ directory + + This test checks the sushi example project, enables the linter, + and verifies that the linter raises a rule violation for the models + that do not have a unit test + """ + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Override the config.py to turn on lint + with open(sushi_path / "config.py", "r") as f: + read_file = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingunittest"]),""" + read_file = read_file.replace(before, after) + assert after in read_file + with open(sushi_path / "config.py", "w") as f: + f.writelines(read_file) + + # Load the context with the temporary sushi path + context = Context(paths=[sushi_path]) + + # Lint the models + lints = context.lint_models(raise_on_error=False) + + # Should have violations for models without tests (most models except customers) + assert len(lints) >= 1 + + # Check that we get violations for models without tests + violation_messages = [lint.violation_msg for lint in lints] + assert any("is missing unit test(s)" in msg for msg in violation_messages) + + # Check that models with existing tests don't have violations + models_with_tests = ["customer_revenue_by_day", "customer_revenue_lifetime", "order_items"] + + for model_name in models_with_tests: + model_violations = [ + lint + for lint in lints + if model_name in lint.violation_msg and "is missing unit test(s)" in 
lint.violation_msg + ] + assert len(model_violations) == 0, ( + f"Model {model_name} should not have a violation since it has a test" + )