
Commit f2e10c5
Add support for model jobs
1 parent a4e3b06

File tree: 14 files changed (+160 -14 lines)

dlt/common/data_writers/writers.py

Lines changed: 37 additions & 1 deletion

@@ -44,7 +44,7 @@
     from dlt.common.libs.pyarrow import pyarrow as pa


-TDataItemFormat = Literal["arrow", "object", "file"]
+TDataItemFormat = Literal["arrow", "object", "file", "model"]
 TWriter = TypeVar("TWriter", bound="DataWriter")


@@ -59,6 +59,8 @@ class FileWriterSpec(NamedTuple):
     """File format supports changes of schema: True - at any moment, Buffer - in memory buffer before opening file, False - not at all"""
     requires_destination_capabilities: bool = False
     supports_compression: bool = False
+    file_max_items: Optional[int] = None
+    """Set an upper limit on the number of items in one file"""


 EMPTY_DATA_WRITER_METRICS = DataWriterMetrics("", 0, 0, 2**32, 0.0)
@@ -115,6 +117,8 @@ def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat:
             return "object"
         elif extension == "parquet":
             return "arrow"
+        elif extension == "model":
+            return "model"
         # those files may be imported by normalizer as is
         elif extension in LOADER_FILE_FORMATS:
             return "file"
@@ -175,6 +179,32 @@ def writer_spec(cls) -> FileWriterSpec:
         )


+class ModelWriter(DataWriter):
+    """Writes incoming items row by row into a text file and ensures a trailing ;"""
+
+    def write_header(self, columns_schema: TTableSchemaColumns) -> None:
+        pass
+
+    def write_data(self, items: Sequence[TDataItem]) -> None:
+        super().write_data(items)
+        self.items_count += len(items)
+        for item in items:
+            self._f.write(item + "\n")
+
+    @classmethod
+    def writer_spec(cls) -> FileWriterSpec:
+        return FileWriterSpec(
+            "model",
+            "model",
+            file_extension="model",
+            is_binary_format=False,
+            supports_schema_changes="True",
+            supports_compression=False,
+            # NOTE: we create a new model file for each sql row
+            file_max_items=1,
+        )
+
+
 class TypedJsonlListWriter(JsonlWriter):
     def write_data(self, items: Sequence[TDataItem]) -> None:
         # skip JsonlWriter when calling super
@@ -670,6 +700,7 @@ def is_native_writer(writer_type: Type[DataWriter]) -> bool:
     ArrowToJsonlWriter,
     ArrowToTypedJsonlListWriter,
     ArrowToCsvWriter,
+    ModelWriter,
 ]

 WRITER_SPECS: Dict[FileWriterSpec, Type[DataWriter]] = {
@@ -689,6 +720,11 @@ def is_native_writer(writer_type: Type[DataWriter]) -> bool:
         for writer in ALL_WRITERS
         if writer.writer_spec().data_item_format == "arrow" and is_native_writer(writer)
     ),
+    "model": tuple(
+        writer
+        for writer in ALL_WRITERS
+        if writer.writer_spec().data_item_format == "model" and is_native_writer(writer)
+    ),
 }

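For orientation, a minimal sketch (assuming this branch of dlt is installed) of how the pieces above fit together: the new "model" extension resolves to the "model" item format, and the spec registered for it caps every file at a single item, so each SQL statement lands in its own .model file.

# Sketch only; relies on the DataWriter helpers touched in this file.
from dlt.common.data_writers.writers import DataWriter

# ".model" files are recognized as the "model" item format
assert DataWriter.item_format_from_file_extension("model") == "model"

# the spec resolved for ("model", "model") comes from ModelWriter above:
# plain text, no compression, one item (one SQL statement) per file
spec = DataWriter.writer_spec_from_file_format("model", "model")
assert spec.file_extension == "model"
assert spec.supports_compression is False
assert spec.file_max_items == 1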

dlt/common/destination/utils.py

Lines changed: 1 addition & 1 deletion

@@ -152,7 +152,7 @@ def verify_supported_data_types(
     for parsed_file in new_jobs:
         formats = table_file_formats.setdefault(parsed_file.table_name, set())
         if parsed_file.file_format in LOADER_FILE_FORMATS:
-            formats.add(parsed_file.file_format)  # type: ignore[arg-type]
+            formats.add(parsed_file.file_format)
     # all file formats
     all_file_formats = set(capabilities.supported_loader_file_formats or []) | set(
         capabilities.supported_staging_file_formats or []

dlt/common/storages/data_item_storage.py

Lines changed: 3 additions & 1 deletion

@@ -28,7 +28,9 @@ def _get_writer(
         if not writer:
             # assign a writer for each table
             path = self._get_data_item_path_template(load_id, schema_name, table_name)
-            writer = BufferedDataWriter(self.writer_spec, path)
+            writer = BufferedDataWriter(
+                self.writer_spec, path, file_max_items=self.writer_spec.file_max_items
+            )
             self.buffered_writers[writer_id] = writer
         return writer

dlt/common/typing.py

Lines changed: 3 additions & 1 deletion

@@ -127,7 +127,9 @@ class SecretSentinel:
 VARIANT_FIELD_FORMAT = "v_%s"
 TFileOrPath = Union[str, PathLike, IO[Any]]
 TSortOrder = Literal["asc", "desc"]
-TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv", "reference"]
+TLoaderFileFormat = Literal[
+    "jsonl", "typed-jsonl", "insert_values", "parquet", "csv", "reference", "sql", "model"
+]
 """known loader file formats"""

 TDynHintType = TypeVar("TDynHintType")

dlt/destinations/impl/duckdb/factory.py

Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ class duckdb(Destination[DuckDbClientConfiguration, "DuckDbClient"]):
     def _raw_capabilities(self) -> DestinationCapabilitiesContext:
         caps = DestinationCapabilitiesContext()
         caps.preferred_loader_file_format = "insert_values"
-        caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl"]
+        caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl", "model"]
         caps.preferred_staging_file_format = None
         caps.supported_staging_file_formats = []
         caps.type_mapper = DuckDbTypeMapper
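A quick way to see the effect (a sketch; it calls the private _raw_capabilities method shown above, so it is for illustration only and assumes the duckdb extra is installed):

# Sketch: the duckdb destination now advertises the "model" loader file format.
from dlt.destinations.impl.duckdb.factory import duckdb

caps = duckdb()._raw_capabilities()
assert caps.preferred_loader_file_format == "insert_values"
assert "model" in caps.supported_loader_file_formats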

dlt/destinations/job_client_impl.py

Lines changed: 23 additions & 0 deletions

@@ -113,6 +113,26 @@ def is_sql_job(file_path: str) -> bool:
         return os.path.splitext(file_path)[1][1:] == "sql"


+class ModelLoadJob(RunnableLoadJob):
+    """
+    A job to insert rows into a table from a model file which contains a list of select statements
+    """
+
+    def __init__(self, file_path: str) -> None:
+        super().__init__(file_path)
+        self._job_client: "SqlJobClientBase" = None
+
+    def run(self) -> None:
+        with FileStorage.open_zipsafe_ro(self._file_path, "r", encoding="utf-8") as f:
+            sql = f.read()
+        self._sql_client = self._job_client.sql_client
+        self._sql_client.execute_sql(sql)
+
+    @staticmethod
+    def is_model_job(file_path: str) -> bool:
+        return os.path.splitext(file_path)[1][1:] == "model"
+
+
 class CopyRemoteFileLoadJob(RunnableLoadJob, HasFollowupJobs):
     def __init__(
         self,
@@ -279,6 +299,9 @@ def create_load_job(
         if SqlLoadJob.is_sql_job(file_path):
            # create sql load job
            return SqlLoadJob(file_path)
+        if ModelLoadJob.is_model_job(file_path):
+            # create model load job
+            return ModelLoadJob(file_path)
         return None

     def complete_load(self, load_id: str) -> None:
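A small sketch of the routing added in create_load_job: dispatch is purely by file extension, so a .model file (whose body is plain SQL text) goes to ModelLoadJob while .sql files keep going to SqlLoadJob. The file names below are made up.

# Sketch: extension-based job routing; the paths are hypothetical examples.
from dlt.destinations.job_client_impl import ModelLoadJob, SqlLoadJob

assert SqlLoadJob.is_sql_job("events.abc123.sql")
assert ModelLoadJob.is_model_job("events.abc123.model")
assert not ModelLoadJob.is_model_job("events.abc123.parquet")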

dlt/extract/extract.py

Lines changed: 5 additions & 2 deletions

@@ -45,7 +45,7 @@
 from dlt.extract.reference import SourceReference
 from dlt.extract.resource import DltResource
 from dlt.extract.storage import ExtractStorage
-from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor
+from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor, TextExtractor
 from dlt.extract.utils import get_data_item_format


@@ -343,6 +343,9 @@ def _extract_single_source(
             "arrow": ArrowExtractor(
                 load_id, self.extract_storage.item_storages["arrow"], schema, collector=collector
             ),
+            "model": TextExtractor(
+                load_id, self.extract_storage.item_storages["model"], schema, collector=collector
+            ),
         }
         # make sure we close storage on exception
         with collector(f"Extract {source.name}"):
@@ -363,7 +366,7 @@ def _extract_single_source(
                         collector.update("Resources", delta)
                     signals.raise_if_signalled()
                     resource = source.resources[pipe_item.pipe.name]
-                    item_format = get_data_item_format(pipe_item.item)
+                    item_format = get_data_item_format(pipe_item.item, pipe_item.meta)
                     extractors[item_format].write_items(
                         resource, pipe_item.item, pipe_item.meta
                     )

dlt/extract/extractors.py

Lines changed: 7 additions & 1 deletion

@@ -62,7 +62,7 @@ def __init__(
         hints: TResourceHints = None,
         create_table_variant: bool = None,
     ) -> None:
-        super().__init__(hints, create_table_variant)
+        super().__init__(hints=hints, create_table_variant=create_table_variant)
         self.file_path = file_path
         self.metrics = metrics
         self.file_format = file_format
@@ -292,6 +292,12 @@ class ObjectExtractor(Extractor):
     pass


+class TextExtractor(Extractor):
+    """Extracts text items and writes them row by row into a text file"""
+
+    pass
+
+
 class ArrowExtractor(Extractor):
     """Extracts arrow data items into parquet. Normalizes arrow items column names.
     Compares the arrow schema to actual dlt table schema to reorder the columns and to

dlt/extract/hints.py

Lines changed: 9 additions & 2 deletions

@@ -43,6 +43,7 @@
 from dlt.extract.items_transform import ValidateItem
 from dlt.extract.utils import ensure_table_schema_columns, ensure_table_schema_columns_hint
 from dlt.extract.validation import create_item_validator
+from dlt.common.data_writers import TDataItemFormat


 class TResourceHintsBase(TypedDict, total=False):
@@ -68,11 +69,17 @@ class TResourceHints(TResourceHintsBase, total=False):


 class HintsMeta:
-    __slots__ = ("hints", "create_table_variant")
+    __slots__ = ("hints", "create_table_variant", "data_item_format")

-    def __init__(self, hints: TResourceHints, create_table_variant: bool) -> None:
+    def __init__(
+        self,
+        hints: TResourceHints,
+        create_table_variant: bool,
+        data_item_format: TDataItemFormat = None,
+    ) -> None:
         self.hints = hints
         self.create_table_variant = create_table_variant
+        self.data_item_format = data_item_format


 NATURAL_CALLABLES = ["incremental", "validator", "original_columns"]

dlt/extract/resource.py

Lines changed: 7 additions & 2 deletions

@@ -29,6 +29,8 @@
     pipeline_state,
 )
 from dlt.common.utils import flatten_list_or_items, get_callable_name, uniq_id
+from dlt.common.data_writers import TDataItemFormat
+
 from dlt.common.schema.typing import TTableSchema
 from dlt.extract.utils import wrap_async_iterator, wrap_parallel_iterator

@@ -72,7 +74,10 @@ def with_table_name(item: TDataItems, table_name: str) -> DataItemWithMeta:


 def with_hints(
-    item: TDataItems, hints: TResourceHints, create_table_variant: bool = False
+    item: TDataItems,
+    hints: TResourceHints = None,
+    create_table_variant: bool = False,
+    data_item_format: TDataItemFormat = None,
 ) -> DataItemWithMeta:
     """Marks `item` to update the resource with specified `hints`.

@@ -81,7 +86,7 @@ def with_hints(
     Create `TResourceHints` with `make_hints`.
     Setting `table_name` will dispatch the `item` to a specified table, like `with_table_name`
     """
-    return DataItemWithMeta(HintsMeta(hints, create_table_variant), item)
+    return DataItemWithMeta(HintsMeta(hints or {}, create_table_variant, data_item_format), item)


 TDltResourceImpl = TypeVar("TDltResourceImpl", bound="DltResource", default="DltResource")
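With the extended with_hints signature a resource can force the item format of a single yielded item instead of relying on detection. A hedged usage sketch (the SQL text and table name are illustrative only; dlt.mark.with_hints is the public alias of the function above):

# Usage sketch: mark a plain SQL string so it is routed to the "model" extractor.
import dlt
from dlt.extract.hints import make_hints

@dlt.resource()
def model_statements():
    # the statement and table name below are made up for illustration
    yield dlt.mark.with_hints(
        "SELECT 1 AS a",
        make_hints(table_name="model_table"),
        data_item_format="model",
    )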

dlt/extract/storage.py

Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,9 @@ def __init__(self, config: NormalizeStorageConfiguration) -> None:
             "arrow": ExtractorItemStorage(
                 self.new_packages, DataWriter.writer_spec_from_file_format("parquet", "arrow")
             ),
+            "model": ExtractorItemStorage(
+                self.new_packages, DataWriter.writer_spec_from_file_format("model", "model")
+            ),
         }

     def create_load_package(self, schema: Schema, reuse_exiting_package: bool = True) -> str:

dlt/extract/utils.py

Lines changed: 10 additions & 1 deletion

@@ -43,6 +43,8 @@
     SupportsPipe,
 )

+from dlt.common.schema.typing import TFileFormat
+
 try:
     from dlt.common.libs import pydantic
 except MissingDependencyException:
@@ -60,14 +62,21 @@
     pandas = None


-def get_data_item_format(items: TDataItems) -> TDataItemFormat:
+def get_data_item_format(items: TDataItems, meta: Any = None) -> TDataItemFormat:
     """Detect the format of the data item from `items`.

     Reverts to `object` for empty lists

     Returns:
         The data file format.
     """
+
+    # if incoming item is hints meta, check if item format is forced
+    from dlt.extract.hints import HintsMeta
+
+    if isinstance(meta, HintsMeta) and meta.data_item_format:
+        return meta.data_item_format
+
     if not pyarrow and not pandas:
         return "object"
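A minimal sketch of the override added above: when the accompanying meta is a HintsMeta that carries a data_item_format, detection is bypassed; otherwise the existing pyarrow/pandas based detection runs as before.

# Sketch: forced format vs. regular detection; HintsMeta arguments follow the new signature.
from dlt.extract.hints import HintsMeta
from dlt.extract.utils import get_data_item_format

forced = HintsMeta(hints={}, create_table_variant=False, data_item_format="model")
assert get_data_item_format("SELECT 1 AS a", meta=forced) == "model"

# without a forcing meta, a list of dicts is still detected as plain "object" items
assert get_data_item_format([{"a": 1}]) == "object"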

dlt/normalize/worker.py

Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@ def _get_items_normalizer(
     if item_format == "file":
         # if we want to import file, create a spec that may be used only for importing
         best_writer_spec = create_import_spec(
-            parsed_file_name.file_format, items_supported_file_formats  # type: ignore[arg-type]
+            parsed_file_name.file_format, items_supported_file_formats
         )

     config_loader_file_format = config.loader_file_format

tests/load/test_sql_resource.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+# test the sql insert job loader, works only on duckdb for now
+
+from typing import Any
+
+import dlt
+
+from dlt.common.destination.dataset import SupportsReadableDataset
+
+from tests.pipeline.utils import load_table_counts
+
+from dlt.extract.hints import make_hints
+
+
+def test_sql_job() -> None:
+    # populate a table with 10 items and retrieve dataset
+    pipeline = dlt.pipeline(
+        pipeline_name="example_pipeline", destination="duckdb", dataset_name="example_dataset"
+    )
+    pipeline.run([{"a": i} for i in range(10)], table_name="example_table")
+    dataset = pipeline.dataset()
+
+    # create a resource that generates sql statements to create 2 new tables
+    @dlt.resource()
+    def copied_table() -> Any:
+        query = dataset["example_table"].limit(5).query()
+        yield dlt.mark.with_hints(
+            f"CREATE OR REPLACE TABLE copied_table AS {query}",
+            make_hints(file_format="sql"),
+        )
+
+        query = dataset["example_table"].limit(7).query()
+        yield dlt.mark.with_hints(
+            f"CREATE OR REPLACE TABLE copied_table2 AS {query}",
+            make_hints(file_format="sql"),
+        )
+
+    # run sql jobs
+    pipeline.run(copied_table())
+
+    # the two tables were created
+    assert load_table_counts(pipeline, "example_table", "copied_table", "copied_table2") == {
+        "example_table": 10,
+        "copied_table": 5,
+        "copied_table2": 7,
+    }
+
+    # we have a table entry for the main table "copied_table"
+    assert "copied_table" in pipeline.default_schema.tables
+    # but no columns, it's up to the user to provide a schema
+    assert len(pipeline.default_schema.tables["copied_table"]["columns"]) == 0
