Skip to content

Commit 10de2d4

Browse files
marashka and Marat Akhmetov authored
[DOP-30298] support spark-dialect-extension 0.0.4
Co-authored-by: Marat Akhmetov <ahmetovmr@mts.ru>
1 parent c5beb92 commit 10de2d4

4 files changed

Lines changed: 55 additions & 3 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Updated ClickHouse ``spark-dialect-extension`` to 0.0.4.

syncmaster/worker/handlers/db/clickhouse.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,36 @@
33

44
from __future__ import annotations
55

6+
from datetime import UTC, datetime
7+
from decimal import Decimal
68
from typing import TYPE_CHECKING, ClassVar
79

810
from onetl.connection import Clickhouse
911
from onetl.db import DBWriter
1012
from onetl.hooks import slot, support_hooks
13+
from pyspark.sql import functions as F # noqa: N812
14+
from pyspark.sql.types import (
15+
ArrayType,
16+
BooleanType,
17+
ByteType,
18+
DataType,
19+
DateType,
20+
DecimalType,
21+
DoubleType,
22+
FloatType,
23+
IntegerType,
24+
LongType,
25+
MapType,
26+
ShortType,
27+
StringType,
28+
TimestampNTZType,
29+
TimestampType,
30+
)
1131

1232
from syncmaster.worker.handlers.db.base import DBHandler
1333

1434
if TYPE_CHECKING:
15-
from pyspark.sql import SparkSession
35+
from pyspark.sql import Column, SparkSession
1636
from pyspark.sql.dataframe import DataFrame
1737

1838
from syncmaster.dto.connections import ClickhouseConnectionDTO
@@ -55,6 +75,7 @@ def write(self, df: DataFrame) -> None:
5575
(col for col in normalized_df.columns if col.lower().endswith("id")),
5676
normalized_df.columns[0], # if there is no column with "id", take the first column
5777
)
78+
normalized_df = self._normalize_column_to_non_nullable(normalized_df, sort_column)
5879
self.transfer_dto.options["createTableOptions"] = (
5980
f"ENGINE = MergeTree() ORDER BY {self._quote_field(sort_column)}"
6081
)
@@ -74,6 +95,36 @@ def _normalize_column_names(self, df: DataFrame) -> DataFrame:
7495
df = df.withColumnRenamed(column_name, column_name.lower())
7596
return df
7697

98+
def _normalize_column_to_non_nullable(self, df: DataFrame, column: str) -> DataFrame:
99+
"""Rewrite ``column`` with ``coalesce`` so Spark marks it non-nullable.
100+
101+
TODO: remove this workaround if Spark infers ``nullable=false`` without a dummy literal:
102+
https://issues.apache.org/jira/browse/SPARK-54302
103+
"""
104+
field = df.schema[column]
105+
if not field.nullable:
106+
return df
107+
108+
field_type: DataType = field.dataType
109+
sentinel: Column
110+
if isinstance(field_type, (ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, BooleanType)):
111+
sentinel = F.lit(0).cast(field_type)
112+
elif isinstance(field_type, DecimalType):
113+
sentinel = F.lit(Decimal(0)).cast(field_type)
114+
elif isinstance(field_type, StringType):
115+
sentinel = F.lit("")
116+
elif isinstance(field_type, (DateType, TimestampType, TimestampNTZType)):
117+
sentinel = F.lit(datetime(1970, 1, 1, 0, 0, 0, tzinfo=UTC)).cast(field_type)
118+
elif isinstance(field_type, ArrayType):
119+
sentinel = F.array().cast(field_type)
120+
elif isinstance(field_type, MapType):
121+
sentinel = F.map_from_arrays(F.array(), F.array()).cast(field_type)
122+
else:
123+
msg = f"Unsupported Spark type for non-null: {field_type!r} (column {column!r})"
124+
raise TypeError(msg)
125+
126+
return df.withColumn(column, F.coalesce(F.col(column), sentinel))
127+
77128
def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
78129
expressions = []
79130
for filter_ in filters:

syncmaster/worker/ivy2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_packages(connection_types: set[str]) -> list[str]:
3131
if connection_types & {"oracle", "all"}:
3232
result.extend(Oracle.get_packages())
3333
if connection_types & {"clickhouse", "all"}:
34-
result.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
34+
result.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.4")
3535
result.extend(Clickhouse.get_packages())
3636
if connection_types & {"mssql", "all"}:
3737
result.extend(MSSQL.get_packages())

tests/test_integration/test_run_transfer/connection_fixtures/spark_fixtures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def spark( # noqa: C901
5151
maven_packages.extend(Oracle.get_packages())
5252

5353
if "clickhouse" in markers:
54-
maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
54+
maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.4")
5555
maven_packages.extend(Clickhouse.get_packages())
5656

5757
if "mssql" in markers:

0 commit comments

Comments
 (0)