Skip to content
22 changes: 12 additions & 10 deletions metadata-ingestion/src/datahub/configuration/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import unittest.mock
from abc import ABC, abstractmethod
from enum import auto
from functools import cached_property as functools_cached_property
from typing import (
IO,
TYPE_CHECKING,
Expand Down Expand Up @@ -133,7 +134,7 @@ def _config_model_schema_extra(schema: Dict[str, Any], model: Type[BaseModel]) -
class ConfigModel(BaseModel):
model_config = ConfigDict(
    extra="forbid",
    # List both the module's existing `cached_property` and the functools
    # variant exactly once; repeating the `ignored_types` keyword is a
    # SyntaxError. Per pydantic's `ignored_types` setting, attributes of
    # these types are left alone instead of being collected by the model.
    ignored_types=(cached_property, functools_cached_property),
    json_schema_extra=_config_model_schema_extra,
    hide_input_in_errors=not get_debug(),
)
Expand Down Expand Up @@ -384,6 +385,14 @@ class AllowDenyPattern(ConfigModel):
def regex_flags(self) -> int:
    """Return the ``re`` flags implied by ``self.ignoreCase``.

    Yields ``re.IGNORECASE`` when ``ignoreCase`` is truthy and ``0``
    (no flags) otherwise.
    """
    if self.ignoreCase:
        return re.IGNORECASE
    return 0

@functools_cached_property
def _compiled_allow(self) -> "List[re.Pattern[str]]":
    """Return the ``allow`` patterns compiled with ``self.regex_flags``.

    As a ``functools.cached_property`` the list is built on first access
    and then reused for this instance, so changes made to ``allow`` or to
    the regex flags after that first access are not picked up.

    Raises:
        re.error: if any entry in ``allow`` is not a valid regular
            expression.
    """
    return [re.compile(pattern, self.regex_flags) for pattern in self.allow]

@functools_cached_property
def _compiled_deny(self) -> "List[re.Pattern[str]]":
    """Return the ``deny`` patterns compiled with ``self.regex_flags``.

    As a ``functools.cached_property`` the list is built on first access
    and then reused for this instance, so changes made to ``deny`` or to
    the regex flags after that first access are not picked up.

    Raises:
        re.error: if any entry in ``deny`` is not a valid regular
            expression.
    """
    return [re.compile(pattern, self.regex_flags) for pattern in self.deny]

@classmethod
def allow_all(cls) -> "AllowDenyPattern":
    """Return a new ``AllowDenyPattern`` built with no arguments.

    Every field therefore takes its declared default.
    NOTE(review): the defaults are declared outside this view; per the
    method name they presumably allow everything -- confirm.
    """
    default_pattern = AllowDenyPattern()
    return default_pattern
Expand All @@ -392,17 +401,10 @@ def allowed(self, string: str) -> bool:
if self.denied(string):
return False

return any(
re.match(allow_pattern, string, self.regex_flags)
for allow_pattern in self.allow
)
return any(pattern.match(string) for pattern in self._compiled_allow)

def denied(self, string: str) -> bool:
    """Return True if ``string`` matches at least one deny pattern.

    Matching uses ``Pattern.match`` on the precompiled patterns from
    ``self._compiled_deny``, so each pattern is tested from the start of
    ``string``. With no deny patterns the result is False.

    Args:
        string: The value to test against the deny patterns.

    Returns:
        True when any deny pattern matches, otherwise False.
    """
    # Single implementation on the cached compiled patterns; this replaces
    # the per-call ``re.match(pattern, string, flags)`` loop, after which a
    # second return statement could never be reached.
    return any(pattern.match(string) for pattern in self._compiled_deny)

def is_fully_specified_allow_list(self) -> bool:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from typing import Dict, List, Optional, Set

import pydantic
Expand Down Expand Up @@ -404,6 +406,13 @@ class SnowflakeV2Config(
"This may be required in the case of _eg_ temporary tables being created in a different database than the ones in the database_name patterns.",
)

@cached_property
def _compiled_temporary_tables_pattern(self) -> "List[re.Pattern[str]]":
    """Return ``temporary_tables_pattern`` as case-insensitive compiled regexes.

    The result keeps the order of ``self.temporary_tables_pattern`` and,
    being a ``cached_property``, is computed once per instance.
    """
    compiled: "List[re.Pattern[str]]" = []
    for raw_pattern in self.temporary_tables_pattern:
        compiled.append(re.compile(raw_pattern, re.IGNORECASE))
    return compiled

@field_validator("convert_urns_to_lowercase", mode="after")
@classmethod
def validate_convert_urns_to_lowercase(cls, v):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tempfile
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import cached_property
from typing import Any, Dict, Iterable, List, Optional, Union

import pydantic
Expand Down Expand Up @@ -147,6 +148,13 @@ class SnowflakeQueriesExtractorConfig(ConfigModel):

query_dedup_strategy: QueryDedupStrategyType = QueryDedupStrategyType.STANDARD

@cached_property
def _compiled_temporary_tables_pattern(self) -> "List[re.Pattern[str]]":
    """Compile each entry of ``temporary_tables_pattern`` with ``re.IGNORECASE``.

    Patterns are returned in their configured order; as a
    ``cached_property`` the list is built once per instance.
    """
    return list(
        map(
            lambda expression: re.compile(expression, re.IGNORECASE),
            self.temporary_tables_pattern,
        )
    )


class SnowflakeQueriesSourceConfig(
SnowflakeQueriesExtractorConfig, SnowflakeIdentifierConfig, SnowflakeFilterConfig
Expand Down Expand Up @@ -284,8 +292,8 @@ def local_temp_path(self) -> pathlib.Path:

def is_temp_table(self, name: str) -> bool:
if any(
re.match(pattern, name, flags=re.IGNORECASE)
for pattern in self.config.temporary_tables_pattern
pattern.match(name)
for pattern in self.config._compiled_temporary_tables_pattern
):
return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import os.path
import platform
import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Union

Expand Down Expand Up @@ -469,8 +468,8 @@ class SnowflakePrivilege:

def _is_temp_table(self, name: str) -> bool:
if any(
re.match(pattern, name, flags=re.IGNORECASE)
for pattern in self.config.temporary_tables_pattern
pattern.match(name)
for pattern in self.config._compiled_temporary_tables_pattern
):
return True

Expand Down
Loading