JoinBasedIndexer, DataPreparator and DataframeBucketizer #49

Open

wants to merge 2 commits into base: sb-main
Changes from all commits
331 changes: 327 additions & 4 deletions replay/data_preparator.py
@@ -6,11 +6,17 @@
``ToNumericFeatureTransformer`` leaves only numerical features
by one-hot encoding of some features and deleting the others.
"""
import json
import logging
import string
from typing import Dict, List, Optional
from functools import singledispatchmethod
from os.path import join
from typing import Dict, List, Optional, overload, Any

from pyspark.ml import Transformer, Estimator
from pyspark.ml.feature import StringIndexerModel, IndexToString, StringIndexer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import MLWriter, MLWritable, MLReader, MLReadable, DefaultParamsWriter, DefaultParamsReader
from pyspark.sql import DataFrame
from pyspark.sql import functions as sf
from pyspark.sql.types import DoubleType, NumericType
@@ -186,7 +192,242 @@ def _reindex(self, df: DataFrame, entity: str):
inv_indexer.setLabels(new_labels)


class DataPreparator:
# Inherit from DefaultParamsWriter so that the instance is saved correctly within a Pipeline
class JoinIndexerMLWriter(DefaultParamsWriter):
"""Implements saving the JoinIndexerTransformer instance to disk.
Used when saving a trained pipeline.
Implements MLWriter.saveImpl(path) method.
"""

def __init__(self, instance):
super().__init__(instance)
self.instance = instance

def saveImpl(self, path: str) -> None:
super().saveImpl(path)
# print(f"Saving {type(self.instance).__name__} to '{path}'")

spark = State().session

init_args = self.instance._init_args
sc = spark.sparkContext
df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
df.coalesce(1).write.mode("overwrite").json(join(path, "init_args.json"))

self.instance.user_col_2_index_map.write.mode("overwrite").save(join(path, "user_col_2_index_map.parquet"))
self.instance.item_col_2_index_map.write.mode("overwrite").save(join(path, "item_col_2_index_map.parquet"))


class JoinIndexerMLReader(MLReader):
"""Loads a ``JoinBasedIndexerTransformer`` instance from disk."""

def load(self, path):
"""Load the ML instance from the input path."""
spark = State().session
args = spark.read.json(join(path, "init_args.json")).first().asDict(recursive=True)
user_col_2_index_map = spark.read.parquet(join(path, "user_col_2_index_map.parquet"))
item_col_2_index_map = spark.read.parquet(join(path, "item_col_2_index_map.parquet"))

indexer = JoinBasedIndexerTransformer(
user_col=args["user_col"],
user_type=args["user_type"],
user_col_2_index_map=user_col_2_index_map,
item_col=args["item_col"],
item_type=args["item_type"],
item_col_2_index_map=item_col_2_index_map,
)

return indexer
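
A round trip through this writer/reader pair could look like the sketch below; the fitted ``indexer`` instance, the interaction ``log`` and the output path are assumptions for illustration, not part of this PR:

indexer.write().save("/tmp/join_indexer")  # delegates to JoinIndexerMLWriter.saveImpl
restored = JoinBasedIndexerTransformer.read().load("/tmp/join_indexer")  # JoinIndexerMLReader.load
indexed_log = restored.transform(log)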


class JoinBasedIndexerTransformer(Transformer, MLWritable, MLReadable):
"""Replaces raw user/item ids with numerical indices by joining with precomputed id-to-index mapping DataFrames."""

def __init__(
self,
user_col: str,
item_col: str,
user_type: str,
item_type: str,
user_col_2_index_map: DataFrame,
item_col_2_index_map: DataFrame,
update_map_on_transform: bool = False,
force_broadcast_on_mapping_joins: bool = True
):
super().__init__()
self.user_col = user_col
self.item_col = item_col
self.user_type = user_type
self.item_type = item_type
self.user_col_2_index_map = user_col_2_index_map
self.item_col_2_index_map = item_col_2_index_map
self.update_map_on_transform = update_map_on_transform
self.force_broadcast_on_mapping_joins = force_broadcast_on_mapping_joins

@property
def _init_args(self):
return {
"user_col": self.user_col,
"item_col": self.item_col,
"user_type": self.user_type,
"item_type": self.item_type,
"update_map_on_transform": self.update_map_on_transform,
"force_broadcast_on_mapping_joins": self.force_broadcast_on_mapping_joins
}

def set_update_map_on_transform(self, value: bool):
"""Sets 'update_map_on_transform' flag"""
self.update_map_on_transform = value

def set_force_broadcast_on_mapping_joins(self, value: bool):
"""Sets 'force_broadcast_on_mapping_joins' flag"""
self.force_broadcast_on_mapping_joins = value

def _get_item_mapping(self) -> DataFrame:
if self.force_broadcast_on_mapping_joins:
mapping = sf.broadcast(self.item_col_2_index_map)
else:
mapping = self.item_col_2_index_map
return mapping

def _get_user_mapping(self) -> DataFrame:
if self.force_broadcast_on_mapping_joins:
mapping = sf.broadcast(self.user_col_2_index_map)
else:
mapping = self.user_col_2_index_map
return mapping

def write(self) -> MLWriter:
"""Returns MLWriter instance that can save the Transformer instance."""
return JoinIndexerMLWriter(self)

@classmethod
def read(cls):
"""Returns an MLReader instance for this class."""
return JoinIndexerMLReader()

def _update_maps(self, df: DataFrame):

new_items = (
df.join(self._get_item_mapping(), on=self.item_col, how="left_anti")
.select(self.item_col).distinct()
)
prev_item_count = self.item_col_2_index_map.count()
new_items_map = (
JoinBasedIndexerEstimator.get_map(new_items, self.item_col, "item_idx")
.select(self.item_col, (sf.col("item_idx") + prev_item_count).alias("item_idx"))
)
self.item_col_2_index_map = self.item_col_2_index_map.union(new_items_map)

new_users = (
df.join(self._get_user_mapping(), on=self.user_col, how="left_anti")
.select(self.user_col).distinct()
)
prev_user_count = self.user_col_2_index_map.count()
new_users_map = (
JoinBasedIndexerEstimator.get_map(new_users, self.user_col, "user_idx")
.select(self.user_col, (sf.col("user_idx") + prev_user_count).alias("user_idx"))
)
self.user_col_2_index_map = self.user_col_2_index_map.union(new_users_map)

def _transform(self, df: DataFrame) -> DataFrame:

if self.update_map_on_transform:
self._update_maps(df)

if self.item_col in df.columns:
remaining_cols = df.drop(self.item_col).columns
df = df.join(self._get_item_mapping(), on=self.item_col, how="left").select(
sf.col("item_idx").cast("int").alias("item_idx"),
*remaining_cols,
)
if self.user_col in df.columns:
remaining_cols = df.drop(self.user_col).columns
df = df.join(self._get_user_mapping(), on=self.user_col, how="left").select(
sf.col("user_idx").cast("int").alias("user_idx"),
*remaining_cols,
)
return df

def inverse_transform(self, df: DataFrame) -> DataFrame:
"""
Convert DataFrame to the initial indexes.

:param df: DataFrame with numerical ``user_idx/item_idx`` columns
:return: DataFrame with original user/item columns
"""
if "item_idx" in df.columns:
remaining_cols = df.drop("item_idx").columns
df = df.join(self._get_item_mapping(), on="item_idx", how="left").select(
sf.col(self.item_col).cast(self.item_type).alias(self.item_col),
*remaining_cols,
)
if "user_idx" in df.columns:
remaining_cols = df.drop("user_idx").columns
df = df.join(self._get_user_mapping(), on="user_idx", how="left").select(
sf.col(self.user_col).cast(self.user_type).alias(self.user_col),
*remaining_cols,
)
return df
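
As an illustration (not part of the diff), mapping indexed recommendations back to raw ids with a fitted ``indexer`` might look like the following; the ``spark`` session and the toy values are assumptions:

recs = spark.createDataFrame([(0, 0, 0.9)], ["user_idx", "item_idx", "relevance"])
raw_recs = indexer.inverse_transform(recs)  # restores the original user/item id columns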


class JoinBasedIndexerEstimator(Estimator):
"""Fits user/item id-to-index mappings and produces a ``JoinBasedIndexerTransformer``."""

def __init__(self, user_col="user_id", item_col="item_id"):
"""
Provide the raw user and item column names for the indexer to use.
"""
self.user_col = user_col
self.item_col = item_col
self.user_col_2_index_map = None
self.item_col_2_index_map = None

@staticmethod
def get_map(df: DataFrame, col_name: str, idx_col_name: str) -> DataFrame:
"""Builds a DataFrame that maps distinct values of ``col_name`` to consecutive indices in ``idx_col_name``."""
uid_rdd = (
df.select(col_name).distinct()
.rdd.map(lambda x: x[col_name])
.zipWithIndex()
)

spark = State().session
_map = spark.createDataFrame(uid_rdd, [col_name, idx_col_name])
return _map

def _fit(self, df: DataFrame) -> Transformer:
"""
Creates indexers to map raw id to numerical idx so that spark can handle them.
:param df: DataFrame containing user column and item column
:return:
"""

self.user_col_2_index_map = self.get_map(df, self.user_col, "user_idx")
self.item_col_2_index_map = self.get_map(df, self.item_col, "item_idx")

self.user_type = df.schema[self.user_col].dataType
self.item_type = df.schema[self.item_col].dataType

return JoinBasedIndexerTransformer(
user_col=self.user_col,
user_type=str(self.user_type),
item_col=self.item_col,
item_type=str(self.item_type),
user_col_2_index_map=self.user_col_2_index_map,
item_col_2_index_map=self.item_col_2_index_map
)
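
A minimal end-to-end sketch of the estimator and the transformer it returns; the toy log, the column names and the ``spark`` session are assumptions for illustration:

log = spark.createDataFrame(
    [("u1", "i1", 1.0), ("u2", "i1", 2.0)],
    ["user_id", "item_id", "relevance"],
)
indexer = JoinBasedIndexerEstimator(user_col="user_id", item_col="item_id").fit(log)
indexed = indexer.transform(log)               # raw ids replaced with user_idx/item_idx
restored = indexer.inverse_transform(indexed)  # back to the original id columns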


class DataPreparatorWriter(DefaultParamsWriter):
def __init__(self, instance: 'DataPreparator'):
super().__init__(instance)


class DataPreparatorReader(DefaultParamsReader):
def __init__(self, cls):
super().__init__(cls)


class DataPreparator(Transformer, MLWritable, MLReadable):
"""Transforms data to a library format:
- read as a Spark dataframe / convert a pandas dataframe to Spark
- check for nulls
@@ -246,9 +487,27 @@ class DataPreparator:
<BLANKLINE>

"""
columnsMapping = Param(Params._dummy(), "columnsMapping", "columns mapping")

_logger: Optional[logging.Logger] = None

def __init__(self, columns_mapping: Optional[Dict[str, str]] = None):
super().__init__()
self.setColumnsMapping(columns_mapping)

def getColumnsMapping(self):
"""Returns the columns mapping dictionary."""
return self.getOrDefault(self.columnsMapping)

def setColumnsMapping(self, value):
"""Sets the columns mapping dictionary."""
self.set(self.columnsMapping, value)

def write(self) -> MLWriter:
return DataPreparatorWriter(self)

@classmethod
def read(cls) -> MLReader:
return DataPreparatorReader(cls)

@property
def logger(self) -> logging.Logger:
"""
@@ -397,10 +656,74 @@ def _rename(df: DataFrame, mapping: Dict) -> Optional[DataFrame]:
df = df.withColumnRenamed(in_col, out_col)
return df

@overload
def transform(self, dataset: DataFrame, params: Optional[Dict[Param, Any]] = None):
"""
:param dataset: DataFrame to process
:param params: A dict with settings to be applied for dataset processing
:return: processed DataFrame
"""
...

# noinspection PyMethodOverriding
@overload
def transform(self,
columns_mapping: Dict[str, str],
data: Optional[AnyDataFrame],
path: Optional[str],
format_type: Optional[str],
date_format: Optional[str],
reader_kwargs: Optional[Dict]) -> DataFrame:
"""
:param columns_mapping: dictionary that maps a key to a column name in the input DataFrame.
Possible keys: ``[user_id, item_id, timestamp, relevance]``
The ``columns_mapping`` specifies the nature of the DataFrame:
- if both ``user_id`` and ``item_id`` are present,
then the dataframe is a log of interactions.
Specify ``timestamp, relevance`` columns in the mapping if present.
- if either ``user_id`` or ``item_id`` is present,
then the dataframe is a dataframe of user/item features

:param data: DataFrame to process
:param path: path to data
:param format_type: file type, one of ``[csv, parquet, json, table]``
:param date_format: format for the ``timestamp`` column
:param reader_kwargs: extra arguments passed to
``spark.read.<format>(path, **reader_kwargs)``
:return: processed DataFrame
"""
...

def transform(self, *args, **kwargs):
"""
Transforms log, user or item features into a Spark DataFrame
``[user_id, user_id, timestamp, relevance]``,
``[user_id, *features]``, or ``[item_id, *features]``.
Input is either file of ``format_type``
at ``path``, or ``pandas.DataFrame`` or ``spark.DataFrame``.
Transform performs:
- dataframe reading/convert to spark DataFrame format
- check dataframe (nulls, columns_mapping)
- rename columns from mapping to standard names (user_id, user_id, timestamp, relevance)
- for interactions log: create absent columns,
convert ``timestamp`` column to TimestampType and ``relevance`` to DoubleType


"""
return self._do_transform(*args, **kwargs)
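
Both call styles funnel into ``_do_transform`` below via ``singledispatchmethod``; a hedged sketch, where the mapping, the column names and ``raw_log`` are illustrative assumptions:

preparator = DataPreparator(columns_mapping={"user_id": "uid", "item_id": "iid"})
# Spark ML style: dispatches on DataFrame and falls through to Transformer.transform/_transform
prepared = preparator.transform(raw_log)
# explicit style: dispatches on the dict passed as the first positional argument
prepared = preparator.transform({"user_id": "uid", "item_id": "iid"}, data=raw_log)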

@singledispatchmethod
def _do_transform(self, dataset: DataFrame, params: Optional[Dict[Param, Any]] = None):
return super().transform(dataset, params)

def _transform(self, dataset):
return self.transform(self.getColumnsMapping(), data=dataset)

# pylint: disable=too-many-arguments
def transform(
@_do_transform.register
def _(
self,
columns_mapping: Dict[str, str],
columns_mapping: dict,
data: Optional[AnyDataFrame] = None,
path: Optional[str] = None,
format_type: Optional[str] = None,