33
44from __future__ import annotations
55
6+ from datetime import UTC , datetime
7+ from decimal import Decimal
68from typing import TYPE_CHECKING , ClassVar
79
810from onetl .connection import Clickhouse
911from onetl .db import DBWriter
1012from onetl .hooks import slot , support_hooks
13+ from pyspark .sql import functions as F # noqa: N812
14+ from pyspark .sql .types import (
15+ ArrayType ,
16+ BooleanType ,
17+ ByteType ,
18+ DataType ,
19+ DateType ,
20+ DecimalType ,
21+ DoubleType ,
22+ FloatType ,
23+ IntegerType ,
24+ LongType ,
25+ MapType ,
26+ ShortType ,
27+ StringType ,
28+ TimestampNTZType ,
29+ TimestampType ,
30+ )
1131
1232from syncmaster .worker .handlers .db .base import DBHandler
1333
1434if TYPE_CHECKING :
15- from pyspark .sql import SparkSession
35+ from pyspark .sql import Column , SparkSession
1636 from pyspark .sql .dataframe import DataFrame
1737
1838 from syncmaster .dto .connections import ClickhouseConnectionDTO
@@ -55,6 +75,7 @@ def write(self, df: DataFrame) -> None:
5575 (col for col in normalized_df .columns if col .lower ().endswith ("id" )),
5676 normalized_df .columns [0 ], # if there is no column with "id", take the first column
5777 )
78+ normalized_df = self ._normalize_column_to_non_nullable (normalized_df , sort_column )
5879 self .transfer_dto .options ["createTableOptions" ] = (
5980 f"ENGINE = MergeTree() ORDER BY { self ._quote_field (sort_column )} "
6081 )
@@ -74,6 +95,36 @@ def _normalize_column_names(self, df: DataFrame) -> DataFrame:
7495 df = df .withColumnRenamed (column_name , column_name .lower ())
7596 return df
7697
98+ def _normalize_column_to_non_nullable (self , df : DataFrame , column : str ) -> DataFrame :
99+ """Rewrite ``column`` with ``coalesce`` so Spark marks it non-nullable.
100+
101+ TODO: remove this workaround if Spark infers ``nullable=false`` without a dummy literal:
102+ https://issues.apache.org/jira/browse/SPARK-54302
103+ """
104+ field = df .schema [column ]
105+ if not field .nullable :
106+ return df
107+
108+ field_type : DataType = field .dataType
109+ sentinel : Column
110+ if isinstance (field_type , (ByteType , ShortType , IntegerType , LongType , FloatType , DoubleType , BooleanType )):
111+ sentinel = F .lit (0 ).cast (field_type )
112+ elif isinstance (field_type , DecimalType ):
113+ sentinel = F .lit (Decimal (0 )).cast (field_type )
114+ elif isinstance (field_type , StringType ):
115+ sentinel = F .lit ("" )
116+ elif isinstance (field_type , (DateType , TimestampType , TimestampNTZType )):
117+ sentinel = F .lit (datetime (1970 , 1 , 1 , 0 , 0 , 0 , tzinfo = UTC )).cast (field_type )
118+ elif isinstance (field_type , ArrayType ):
119+ sentinel = F .array ().cast (field_type )
120+ elif isinstance (field_type , MapType ):
121+ sentinel = F .map_from_arrays (F .array (), F .array ()).cast (field_type )
122+ else :
123+ msg = f"Unsupported Spark type for non-null: { field_type !r} (column { column !r} )"
124+ raise TypeError (msg )
125+
126+ return df .withColumn (column , F .coalesce (F .col (column ), sentinel ))
127+
77128 def _make_rows_filter_expression (self , filters : list [dict ]) -> str | None :
78129 expressions = []
79130 for filter_ in filters :
0 commit comments