diff --git a/replay/data_preparator.py b/replay/data_preparator.py
index feeac0cc3..a1c8e9e54 100644
--- a/replay/data_preparator.py
+++ b/replay/data_preparator.py
@@ -6,11 +6,17 @@
 ``ToNumericFeatureTransformer`` leaves only numerical features
 by one-hot encoding of some features and deleting the others.
 """
+import json
 import logging
 import string
-from typing import Dict, List, Optional
+from functools import singledispatchmethod
+from os.path import join
+from typing import Dict, List, Optional, overload, Any

+from pyspark.ml import Transformer, Estimator
 from pyspark.ml.feature import StringIndexerModel, IndexToString, StringIndexer
+from pyspark.ml.param import Param, Params
+from pyspark.ml.util import MLWriter, MLWritable, MLReader, MLReadable, DefaultParamsWriter, DefaultParamsReader
 from pyspark.sql import DataFrame
 from pyspark.sql import functions as sf
 from pyspark.sql.types import DoubleType, NumericType
@@ -186,7 +192,242 @@ def _reindex(self, df: DataFrame, entity: str):
             inv_indexer.setLabels(new_labels)


-class DataPreparator:
+# Inherits from DefaultParamsWriter so that the transformer is saved correctly inside a Pipeline
+class JoinIndexerMLWriter(DefaultParamsWriter):
+    """Implements saving a ``JoinBasedIndexerTransformer`` instance to disk.
+    Used when saving a trained pipeline.
+    Implements the ``MLWriter.saveImpl(path)`` method.
+    """
+
+    def __init__(self, instance):
+        super().__init__(instance)
+        self.instance = instance
+
+    def saveImpl(self, path: str) -> None:
+        super().saveImpl(path)
+
+        spark = State().session
+
+        init_args = self.instance._init_args
+        sc = spark.sparkContext
+        df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
+        df.coalesce(1).write.mode("overwrite").json(join(path, "init_args.json"))
+
+        self.instance.user_col_2_index_map.write.mode("overwrite").save(join(path, "user_col_2_index_map.parquet"))
+        self.instance.item_col_2_index_map.write.mode("overwrite").save(join(path, "item_col_2_index_map.parquet"))
+
+
+class JoinIndexerMLReader(MLReader):
+    """Implements loading a ``JoinBasedIndexerTransformer`` instance from disk."""
+
+    def load(self, path):
+        """Load the ML instance from the input path."""
+        spark = State().session
+        args = spark.read.json(join(path, "init_args.json")).first().asDict(recursive=True)
+        user_col_2_index_map = spark.read.parquet(join(path, "user_col_2_index_map.parquet"))
+        item_col_2_index_map = spark.read.parquet(join(path, "item_col_2_index_map.parquet"))
+
+        indexer = JoinBasedIndexerTransformer(
+            user_col=args["user_col"],
+            user_type=args["user_type"],
+            user_col_2_index_map=user_col_2_index_map,
+            item_col=args["item_col"],
+            item_type=args["item_type"],
+            item_col_2_index_map=item_col_2_index_map,
+        )
+
+        return indexer
+
+
+class JoinBasedIndexerTransformer(Transformer, MLWritable, MLReadable):
+    """Replaces raw user/item ids with numeric indices by joining the input
+    dataframe with the fitted id-to-index mapping dataframes.
+    """
+
+    def __init__(
+        self,
+        user_col: str,
+        item_col: str,
+        user_type: str,
+        item_type: str,
+        user_col_2_index_map: DataFrame,
+        item_col_2_index_map: DataFrame,
+        update_map_on_transform: bool = False,
+        force_broadcast_on_mapping_joins: bool = True
+    ):
+        super().__init__()
+        self.user_col = user_col
+        self.item_col = item_col
+        self.user_type = user_type
+        self.item_type = item_type
+        self.user_col_2_index_map = user_col_2_index_map
+        self.item_col_2_index_map = item_col_2_index_map
+        self.update_map_on_transform = update_map_on_transform
+        self.force_broadcast_on_mapping_joins = force_broadcast_on_mapping_joins
+
+    @property
+    def _init_args(self):
+        return {
+            "user_col": self.user_col,
+            "item_col": self.item_col,
+            "user_type": self.user_type,
+            "item_type": self.item_type,
+            "update_map_on_transform": self.update_map_on_transform,
+            "force_broadcast_on_mapping_joins": self.force_broadcast_on_mapping_joins
+        }
+
+    def set_update_map_on_transform(self, value: bool):
+        """Sets 'update_map_on_transform' flag"""
+        self.update_map_on_transform = value
+
+    def set_force_broadcast_on_mapping_joins(self, value: bool):
+        """Sets 'force_broadcast_on_mapping_joins' flag"""
+        self.force_broadcast_on_mapping_joins = value
+
+    def _get_item_mapping(self) -> DataFrame:
+        if self.force_broadcast_on_mapping_joins:
+            mapping = sf.broadcast(self.item_col_2_index_map)
+        else:
+            mapping = self.item_col_2_index_map
+        return mapping
+
+    def _get_user_mapping(self) -> DataFrame:
+        if self.force_broadcast_on_mapping_joins:
+            mapping = sf.broadcast(self.user_col_2_index_map)
+        else:
+            mapping = self.user_col_2_index_map
+        return mapping
+
+    def write(self) -> MLWriter:
+        """Returns MLWriter instance that can save the Transformer instance."""
+        return JoinIndexerMLWriter(self)
+
+    @classmethod
+    def read(cls):
+        """Returns an MLReader instance for this class."""
+        return JoinIndexerMLReader()
+
+    def _update_maps(self, df: DataFrame):
+        new_items = (
+            df.join(self._get_item_mapping(), on=self.item_col, how="left_anti")
+            .select(self.item_col).distinct()
+        )
+        prev_item_count = self.item_col_2_index_map.count()
+        new_items_map = (
+            JoinBasedIndexerEstimator.get_map(new_items, self.item_col, "item_idx")
+            .select(self.item_col, (sf.col("item_idx") + prev_item_count).alias("item_idx"))
+        )
+        self.item_col_2_index_map = self.item_col_2_index_map.union(new_items_map)
+
+        new_users = (
+            df.join(self._get_user_mapping(), on=self.user_col, how="left_anti")
+            .select(self.user_col).distinct()
+        )
+        prev_user_count = self.user_col_2_index_map.count()
+        new_users_map = (
+            JoinBasedIndexerEstimator.get_map(new_users, self.user_col, "user_idx")
+            .select(self.user_col, (sf.col("user_idx") + prev_user_count).alias("user_idx"))
+        )
+        self.user_col_2_index_map = self.user_col_2_index_map.union(new_users_map)
+
+    def _transform(self, df: DataFrame) -> DataFrame:
+        if self.update_map_on_transform:
+            self._update_maps(df)
+
+        if self.item_col in df.columns:
+            remaining_cols = df.drop(self.item_col).columns
+            df = df.join(self._get_item_mapping(), on=self.item_col, how="left").select(
+                sf.col("item_idx").cast("int").alias("item_idx"),
+                *remaining_cols,
+            )
+        if self.user_col in df.columns:
+            remaining_cols = df.drop(self.user_col).columns
+            df = df.join(self._get_user_mapping(), on=self.user_col, how="left").select(
+                sf.col("user_idx").cast("int").alias("user_idx"),
+                *remaining_cols,
+            )
+        return df
+
+    def inverse_transform(self, df: DataFrame) -> DataFrame:
+        """
+        Convert DataFrame to the initial indexes.
+
+        :param df: DataFrame with numerical ``user_idx/item_idx`` columns
+        :return: DataFrame with original user/item columns
+        """
+        if "item_idx" in df.columns:
+            remaining_cols = df.drop("item_idx").columns
+            df = df.join(self._get_item_mapping(), on="item_idx", how="left").select(
+                sf.col(self.item_col).cast(self.item_type).alias(self.item_col),
+                *remaining_cols,
+            )
+        if "user_idx" in df.columns:
+            remaining_cols = df.drop("user_idx").columns
+            df = df.join(self._get_user_mapping(), on="user_idx", how="left").select(
+                sf.col(self.user_col).cast(self.user_type).alias(self.user_col),
+                *remaining_cols,
+            )
+        return df
+
+
+class JoinBasedIndexerEstimator(Estimator):
+    """Fits the user/item id-to-index mappings used by ``JoinBasedIndexerTransformer``."""
+
+    def __init__(self, user_col="user_id", item_col="item_id"):
+        """
+        Provide the column names the indexer should use.
+        """
+        super().__init__()
+        self.user_col = user_col
+        self.item_col = item_col
+        self.user_col_2_index_map = None
+        self.item_col_2_index_map = None
+
+    @staticmethod
+    def get_map(df: DataFrame, col_name: str, idx_col_name: str) -> DataFrame:
+        """Maps distinct values of ``col_name`` to consecutive indices in ``idx_col_name``."""
+        uid_rdd = (
+            df.select(col_name).distinct()
+            .rdd.map(lambda x: x[col_name])
+            .zipWithIndex()
+        )
+
+        spark = State().session
+        _map = spark.createDataFrame(uid_rdd, [col_name, idx_col_name])
+        return _map
+
+    def _fit(self, df: DataFrame) -> Transformer:
+        """
+        Creates mappings from raw ids to numerical indices so that Spark models can handle them.
+
+        :param df: DataFrame containing the user and item columns
+        :return: fitted ``JoinBasedIndexerTransformer``
+        """
+        self.user_col_2_index_map = self.get_map(df, self.user_col, "user_idx")
+        self.item_col_2_index_map = self.get_map(df, self.item_col, "item_idx")
+
+        self.user_type = df.schema[self.user_col].dataType
+        self.item_type = df.schema[self.item_col].dataType
+
+        return JoinBasedIndexerTransformer(
+            user_col=self.user_col,
+            user_type=str(self.user_type),
+            item_col=self.item_col,
+            item_type=str(self.item_type),
+            user_col_2_index_map=self.user_col_2_index_map,
+            item_col_2_index_map=self.item_col_2_index_map
+        )
+
+
+class DataPreparatorWriter(DefaultParamsWriter):
+    """Saves ``DataPreparator`` params with ``DefaultParamsWriter``."""
+
+    def __init__(self, instance: 'DataPreparator'):
+        super().__init__(instance)
+
+
+class DataPreparatorReader(DefaultParamsReader):
+    """Loads ``DataPreparator`` params with ``DefaultParamsReader``."""
+
+    def __init__(self, cls):
+        super().__init__(cls)
+
+
+class DataPreparator(Transformer, MLWritable, MLReadable):
     """Transforms data to a library format:
     - read as a spark dataframe/ convert pandas dataframe to spark
     - check for nulls
@@ -246,9 +487,27 @@ class DataPreparator:
     """

+    columnsMapping = Param(Params._dummy(), "columnsMapping", "columns mapping")
+
     _logger: Optional[logging.Logger] = None

+    def __init__(self, columns_mapping: Optional[Dict[str, str]] = None):
+        super().__init__()
+        self.setColumnsMapping(columns_mapping)
+
+    def getColumnsMapping(self):
+        """Returns the ``columnsMapping`` param value."""
+        return self.getOrDefault(self.columnsMapping)
+
+    def setColumnsMapping(self, value):
+        """Sets the ``columnsMapping`` param value."""
+        self.set(self.columnsMapping, value)
+
+    def write(self) -> MLWriter:
+        """Returns an ``MLWriter`` for this ``DataPreparator``."""
+        return DataPreparatorWriter(self)
+
+    @classmethod
+    def read(cls) -> MLReader:
+        """Returns an ``MLReader`` for this class."""
+        return DataPreparatorReader(cls)
+
     @property
     def logger(self) -> logging.Logger:
         """
@@ -397,10 +656,74 @@ def _rename(df: DataFrame, mapping: Dict) -> Optional[DataFrame]:
                 df = df.withColumnRenamed(in_col, out_col)
         return df

+    @overload
+    def transform(self, dataset: DataFrame, params: Optional[Dict[Param, Any]] = None):
+        """
+        :param dataset: DataFrame to process
+        :param params: a dict with settings to be applied during dataset processing
+        :return: processed DataFrame
+        """
+        ...
+
+    # noinspection PyMethodOverriding
+    @overload
+    def transform(self,
+                  columns_mapping: Dict[str, str],
+                  data: Optional[AnyDataFrame],
+                  path: Optional[str],
+                  format_type: Optional[str],
+                  date_format: Optional[str],
+                  reader_kwargs: Optional[Dict]) -> DataFrame:
+        """
+        :param columns_mapping: dictionary mapping "key: column name in input DataFrame".
+            Possible keys: ``[user_id, item_id, timestamp, relevance]``.
+            The ``columns_mapping`` values specify the nature of the DataFrame:
+            - if both ``[user_id, item_id]`` are present,
+              then the dataframe is a log of interactions.
+              Specify ``timestamp, relevance`` columns in the mapping if present.
+            - if either ``user_id`` or ``item_id`` is present,
+              then the dataframe is a dataframe of user/item features
+
+        :param data: DataFrame to process
+        :param path: path to data
+        :param format_type: file type, one of ``[csv, parquet, json, table]``
+        :param date_format: format for the ``timestamp`` column
+        :param reader_kwargs: extra arguments passed to
+            ``spark.read.<format_type>(path, **reader_kwargs)``
+        :return: processed DataFrame
+        """
+        ...
+
+    def transform(self, *args, **kwargs):
+        """
+        Transforms a log, user features or item features into a Spark DataFrame
+        ``[user_id, item_id, timestamp, relevance]``,
+        ``[user_id, *features]``, or ``[item_id, *features]``.
+        Input is either a file of ``format_type`` at ``path``,
+        or a ``pandas.DataFrame`` or ``spark.DataFrame``.
+        Transform performs:
+        - reading the dataframe / converting it to a Spark DataFrame
+        - checking the dataframe (nulls, columns_mapping)
+        - renaming columns from the mapping to standard names
+          (user_id, item_id, timestamp, relevance)
+        - for an interactions log: creating absent columns, converting the
+          ``timestamp`` column to TimestampType and ``relevance`` to DoubleType
+        """
+        return self._do_transform(*args, **kwargs)
+
+    @singledispatchmethod
+    def _do_transform(self, dataset: DataFrame, params: Optional[Dict[Param, Any]] = None):
+        return super().transform(dataset, params)
+
+    def _transform(self, dataset):
+        return self.transform(self.getColumnsMapping(), data=dataset)
+
     # pylint: disable=too-many-arguments
-    def transform(
+    @_do_transform.register
+    def _(
         self,
-        columns_mapping: Dict[str, str],
+        columns_mapping: dict,
         data: Optional[AnyDataFrame] = None,
         path: Optional[str] = None,
         format_type: Optional[str] = None,
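A minimal usage sketch for the new join-based indexer (not part of the patch; ``log``, ``new_interactions`` and the save path are placeholder names, and a Spark session is assumed to be available through ``replay.session_handler.State``, as the classes above assume):

    from replay.data_preparator import (
        JoinBasedIndexerEstimator,
        JoinBasedIndexerTransformer,
    )

    # fit id-to-index mapping dataframes on the raw interaction log
    indexer = JoinBasedIndexerEstimator(user_col="user_id", item_col="item_id").fit(log)

    # replace raw ids with integer user_idx / item_idx columns
    indexed_log = indexer.transform(log)

    # optionally extend the mappings when unseen users/items arrive later
    indexer.set_update_map_on_transform(True)
    indexed_new = indexer.transform(new_interactions)

    # map the indices back to the original id columns
    restored = indexer.inverse_transform(indexed_log)

    # persistence goes through JoinIndexerMLWriter / JoinIndexerMLReader
    indexer.write().save("/tmp/join_based_indexer")
    loaded = JoinBasedIndexerTransformer.read().load("/tmp/join_based_indexer")
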
diff --git a/replay/dataframe_bucketizer.py b/replay/dataframe_bucketizer.py
new file mode 100644
index 000000000..7da1205cf
--- /dev/null
+++ b/replay/dataframe_bucketizer.py
@@ -0,0 +1,99 @@
+from pyspark.ml import Transformer
+from pyspark.ml.param import TypeConverters, Params, Param
+from pyspark.ml.util import DefaultParamsWritable, DefaultParamsReadable
+from pyspark.sql import DataFrame
+
+from replay.session_handler import State
+
+
+class DataframeBucketizer(Transformer, DefaultParamsWritable, DefaultParamsReadable):
+    """
+    Buckets the input dataframe, writes it to the Spark warehouse directory,
+    and returns the bucketed dataframe.
+    """
+
+    bucketingKey = Param(
+        Params._dummy(),
+        "bucketingKey",
+        "bucketing key (also used as sort key)",
+        typeConverter=TypeConverters.toString,
+    )
+
+    partitionNum = Param(
+        Params._dummy(),
+        "partitionNum",
+        "number of buckets",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    tableName = Param(
+        Params._dummy(),
+        "tableName",
+        "parquet file name (for storage in 'spark-warehouse') and spark table name",
+        typeConverter=TypeConverters.toString,
+    )
+
+    sparkWarehouseDir = Param(
+        Params._dummy(),
+        "sparkWarehouseDir",
+        "spark warehouse directory, i.e. the value of 'spark.sql.warehouse.dir'",
+        typeConverter=TypeConverters.toString,
+    )
+
+    def __init__(self, bucketing_key: str, partition_num: int, spark_warehouse_dir: str, table_name: str = ""):
+        """Makes a bucketed dataframe from the input dataframe.
+
+        :param bucketing_key: bucketing key (also used as sort key)
+        :param partition_num: number of buckets
+        :param spark_warehouse_dir: spark warehouse dir, i.e. the value
+            of the 'spark.sql.warehouse.dir' property
+        :param table_name: parquet file name (for storage in 'spark-warehouse')
+            and spark table name
+        """
+        super().__init__()
+        self.set(self.bucketingKey, bucketing_key)
+        self.set(self.partitionNum, partition_num)
+        self.set(self.tableName, table_name)
+        self.set(self.sparkWarehouseDir, spark_warehouse_dir)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.remove_parquet()
+
+    def remove_parquet(self):
+        """Removes the bucketed table files from the spark warehouse directory."""
+        spark = State().session
+        spark_warehouse_dir = self.getOrDefault(self.sparkWarehouseDir)
+        table_name = self.getOrDefault(self.tableName)
+        fs = spark._jvm.org.apache.hadoop.fs.FileSystem.get(spark._jsc.hadoopConfiguration())
+        fs_path = spark._jvm.org.apache.hadoop.fs.Path(f"{spark_warehouse_dir}/{table_name}")
+        if fs.exists(fs_path):
+            fs.delete(fs_path, True)
+
+    def set_table_name(self, table_name: str):
+        """Sets the name of the table used to store the bucketed dataframe."""
+        self.set(self.tableName, table_name)
+
+    def _transform(self, df: DataFrame):
+        bucketing_key = self.getOrDefault(self.bucketingKey)
+        partition_num = self.getOrDefault(self.partitionNum)
+        table_name = self.getOrDefault(self.tableName)
+        spark_warehouse_dir = self.getOrDefault(self.sparkWarehouseDir)
+
+        if not table_name:
+            raise ValueError(
+                "Parameter 'table_name' is not set. Set it via the 'set_table_name' method."
+            )
+
+        (
+            df.repartition(partition_num, bucketing_key)
+            .write.mode("overwrite")
+            .bucketBy(partition_num, bucketing_key)
+            .sortBy(bucketing_key)
+            .saveAsTable(
+                table_name,
+                format="parquet",
+                path=f"{spark_warehouse_dir}/{table_name}",
+            )
+        )
+
+        spark = State().session
+
+        return spark.table(table_name)
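
``DataframeBucketizer`` is intended to be used as a context manager so that the bucketed table is cleaned up afterwards. A rough usage sketch (not part of the patch; ``interactions`` and the table name are placeholders, and the warehouse directory is read from the active session's config):

    from replay.dataframe_bucketizer import DataframeBucketizer
    from replay.session_handler import State

    spark = State().session

    with DataframeBucketizer(
        bucketing_key="user_idx",
        partition_num=16,
        spark_warehouse_dir=spark.conf.get("spark.sql.warehouse.dir"),
        table_name="bucketed_interactions",
    ) as bucketizer:
        bucketed = bucketizer.transform(interactions)
        # downstream joins on user_idx against tables bucketed the same way
        # can now avoid a shuffle

    # __exit__ calls remove_parquet(), removing the bucketed files
    # from the warehouse directory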