
Added metrics calculation via Scala UDFs #50

Open · wants to merge 12 commits into sb-main
22 changes: 21 additions & 1 deletion .gitignore
@@ -29,7 +29,7 @@ lib/
lib64
parts/
sdist/
src/
src/*
var/
wheels/
pip-wheel-metadata/
@@ -148,4 +148,24 @@ tmp/
docs/_build/
*.DS_Store
*/catboost_info

# Spark files
metastore_db/

.vscode

# VSCode Scala extension files
.metals
.bloop
.bsp

# meld
*.orig

# supplementary files
rsync-repo.sh
requirements.txt
airflow.yaml

# temporary
experiments/tests
Binary file added jars/replay_2.12-0.1.jar
40 changes: 39 additions & 1 deletion replay/metrics/base_metric.py
@@ -7,14 +7,17 @@
from typing import Dict, List, Tuple, Union, Optional

import pandas as pd
from pyspark.sql import Column
from pyspark.sql import DataFrame
from pyspark.sql import functions as sf
from pyspark.sql import types as st
from pyspark.sql.types import DataType
from pyspark.sql import Window
from pyspark.sql.column import _to_java_column, _to_seq
from scipy.stats import norm

from replay.constants import AnyDataFrame, IntOrList, NumType
from replay.session_handler import State
from replay.utils import convert2spark, get_top_k_recs


@@ -162,6 +165,10 @@ class Metric(ABC):
"""Base metric class"""

_logger: Optional[logging.Logger] = None
_scala_udf_name: Optional[str] = None

def __init__(self, use_scala_udf: bool = False) -> None:
self._use_scala_udf = use_scala_udf

@property
def logger(self) -> logging.Logger:
@@ -172,6 +179,14 @@ def logger(self) -> logging.Logger:
self._logger = logging.getLogger("replay")
return self._logger

@property
def scala_udf_name(self) -> str:
"""Returns UDF name from `org.apache.spark.replay.utils.ScalaPySparkUDFs`"""
if self._scala_udf_name:
return self._scala_udf_name
else:
raise NotImplementedError(f"Scala UDF not implemented for {type(self).__name__} class!")

def __str__(self):
return type(self).__name__

@@ -254,6 +269,12 @@ def _get_metric_distribution(self, recs: DataFrame, k: int) -> DataFrame:
:param k: depth cut-off
:return: metric distribution for different cut-offs and users
"""
if self._use_scala_udf:
metric_value_col = self.get_scala_udf(
self.scala_udf_name, [sf.lit(k).alias("k"), *recs.columns[1:]]
).alias("value")
return recs.select("user_idx", metric_value_col)

cur_class = self.__class__
distribution = recs.rdd.flatMap(
# pylint: disable=protected-access
@@ -333,13 +354,28 @@ def user_distribution(
res = res.append(val, ignore_index=True)
return res

@staticmethod
def get_scala_udf(udf_name: str, params: List) -> Column:
"""
Returns a column expression that calls the Scala UDF

:param udf_name: UDF name from `org.apache.spark.replay.utils.ScalaPySparkUDFs`
:param params: list of UDF parameters in the correct order
:return: column expression
"""
sc = State().session.sparkContext # pylint: disable=invalid-name
scala_udf = getattr(
sc._jvm.org.apache.spark.replay.utils.ScalaPySparkUDFs, udf_name
)()
return Column(scala_udf.apply(_to_seq(sc, params, _to_java_column)))


# pylint: disable=too-few-public-methods
class RecOnlyMetric(Metric):
"""Base class for metrics that do not need holdout data"""

@abstractmethod
def __init__(self, log: AnyDataFrame, *args, **kwargs):
def __init__(self, log: AnyDataFrame, *args, **kwargs): # pylint: disable=super-init-not-called
pass

# pylint: disable=no-self-use
@@ -402,6 +438,7 @@ def __init__(
prev_policy_weights: AnyDataFrame,
threshold: float = 10.0,
activation: Optional[str] = None,
use_scala_udf: bool = False,
): # pylint: disable=super-init-not-called
"""
:param prev_policy_weights: historical item of user-item relevance (previous policy values)
@@ -410,6 +447,7 @@
:activation: activation function, applied over relevance values.
"logit"/"sigmoid", "softmax" or None
"""
self._use_scala_udf = use_scala_udf
self.prev_policy_weights = convert2spark(
prev_policy_weights
).withColumnRenamed("relevance", "prev_relevance")
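For readers less familiar with the py4j bridge that `get_scala_udf` relies on, the sketch below shows the same pattern in isolation: put the jar on the driver classpath, fetch the compiled UDF object from the JVM, and wrap its `apply` call in a PySpark `Column`. The `ScalaPySparkUDFs` object path and the `getHitRateMetricValue` accessor are taken from this PR; the jar location, the literal `k`, and the column names are illustrative assumptions.

```python
from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as sf
from pyspark.sql.column import _to_java_column, _to_seq

spark = (
    SparkSession.builder
    .config("spark.jars", "jars/replay_2.12-0.1.jar")  # jar added by this PR
    .getOrCreate()
)
sc = spark.sparkContext

# Look up the UDF instance on the JVM side, then wrap the Java column returned
# by its `apply` call into a PySpark Column, as Metric.get_scala_udf does.
scala_udf = sc._jvm.org.apache.spark.replay.utils.ScalaPySparkUDFs.getHitRateMetricValue()
params = [sf.lit(10).alias("k"), "pred", "ground_truth"]  # order must match the Scala signature
hit_rate_col = Column(scala_udf.apply(_to_seq(sc, params, _to_java_column)))
```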
2 changes: 2 additions & 0 deletions replay/metrics/hitrate.py
@@ -17,6 +17,8 @@ class HitRate(Metric):

"""

_scala_udf_name = "getHitRateMetricValue"

@staticmethod
def _get_metric_value_by_user(k, pred, ground_truth) -> float:
for i in pred[:k]:
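A usage sketch for the new flag, assuming the library's usual `Metric.__call__(recommendations, ground_truth, k)` interface and its default `user_idx`/`item_idx`/`relevance` columns; the toy data below is purely illustrative:

```python
import pandas as pd
from replay.metrics import HitRate

# Toy inputs; replay converts pandas DataFrames to Spark internally.
recs = pd.DataFrame(
    {"user_idx": [0, 0, 1, 1], "item_idx": [1, 2, 3, 4], "relevance": [0.9, 0.8, 0.7, 0.6]}
)
ground_truth = pd.DataFrame(
    {"user_idx": [0, 1], "item_idx": [2, 5], "relevance": [1.0, 1.0]}
)

metric = HitRate(use_scala_udf=True)    # per-user values come from the Scala UDF
print(metric(recs, ground_truth, k=2))  # should match the default Python implementation
```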
2 changes: 2 additions & 0 deletions replay/metrics/map.py
@@ -15,6 +15,8 @@ class MAP(Metric):
:math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing if user :math:`i` interacted with item :math:`j`
"""

_scala_udf_name = "getMAPMetricValue"

@staticmethod
def _get_metric_value_by_user(k, pred, ground_truth) -> float:
length = min(k, len(pred))
2 changes: 2 additions & 0 deletions replay/metrics/mrr.py
@@ -19,6 +19,8 @@ class MRR(Metric):
1.0
"""

_scala_udf_name = "getMRRMetricValue"

@staticmethod
def _get_metric_value_by_user(k, pred, ground_truth) -> float:
for i in range(min(k, len(pred))):
2 changes: 2 additions & 0 deletions replay/metrics/ncis_precision.py
@@ -21,6 +21,8 @@ class NCISPrecision(NCISMetric):
Source: arxiv.org/abs/1801.07030
"""

_scala_udf_name = "getNCISPrecisionMetricValue"

@staticmethod
def _get_metric_value_by_user(k, *args):
pred, ground_truth, pred_weights = args
2 changes: 2 additions & 0 deletions replay/metrics/ndcg.py
@@ -46,6 +46,8 @@ class NDCG(Metric):
0.5
"""

_scala_udf_name = "getNDCGMetricValue"

@staticmethod
def _get_metric_value_by_user(k, pred, ground_truth) -> float:
if len(pred) == 0 or len(ground_truth) == 0:
2 changes: 2 additions & 0 deletions replay/metrics/precision.py
@@ -14,6 +14,8 @@ class Precision(Metric):

:math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing that user :math:`i` interacted with item :math:`j`"""

_scala_udf_name = "getPrecisionMetricValue"

@staticmethod
def _get_metric_value_by_user(k, pred, ground_truth) -> float:
if len(pred) == 0:
2 changes: 2 additions & 0 deletions replay/metrics/recall.py
@@ -17,6 +17,8 @@ class Recall(Metric):
:math:`|Rel_i|` -- the number of relevant items for user :math:`i`
"""

_scala_udf_name = "getRecallMetricValue"

@staticmethod
def _get_metric_value_by_user(k, pred, ground_truth) -> float:
if len(ground_truth) == 0:
2 changes: 2 additions & 0 deletions replay/metrics/rocauc.py
@@ -40,6 +40,8 @@ class RocAuc(Metric):

"""

_scala_udf_name = "getRocAucMetricValue"

@staticmethod
def _get_metric_value_by_user(k, pred, ground_truth) -> float:
length = min(k, len(pred))
8 changes: 6 additions & 2 deletions replay/metrics/surprisal.py
@@ -7,12 +7,12 @@
from pyspark.sql import types as st

from replay.constants import AnyDataFrame
from replay.utils import convert2spark, get_top_k_recs
from replay.metrics.base_metric import (
fill_na_with_empty_array,
RecOnlyMetric,
sorter,
)
from replay.utils import convert2spark, get_top_k_recs


# pylint: disable=too-few-public-methods
@@ -45,14 +45,18 @@ class Surprisal(RecOnlyMetric):
Surprisal@K = \\frac {\sum_{i=1}^{N}Surprisal@K(i)}{N}
"""

_scala_udf_name = "getSurprisalMetricValue"

def __init__(
self, log: AnyDataFrame
self, log: AnyDataFrame,
use_scala_udf: bool = False
): # pylint: disable=super-init-not-called
"""
Here we calculate self-information for each item

:param log: historical data
"""
self._use_scala_udf = use_scala_udf
self.log = convert2spark(log)
n_users = self.log.select("user_idx").distinct().count() # type: ignore
self.item_weights = self.log.groupby("item_idx").agg(
9 changes: 7 additions & 2 deletions replay/metrics/unexpectedness.py
@@ -1,15 +1,16 @@
from typing import Optional

from pyspark.sql import DataFrame
from pyspark.sql import functions as sf
from pyspark.sql import types as st

from replay.constants import AnyDataFrame
from replay.utils import convert2spark, get_top_k_recs
from replay.metrics.base_metric import (
RecOnlyMetric,
sorter,
fill_na_with_empty_array,
)
from replay.utils import convert2spark, get_top_k_recs


# pylint: disable=too-few-public-methods
@@ -29,12 +30,16 @@ class Unexpectedness(RecOnlyMetric):
0.67
"""

_scala_udf_name = "getUnexpectednessMetricValue"

def __init__(
self, pred: AnyDataFrame
self, pred: AnyDataFrame,
use_scala_udf: bool = False
): # pylint: disable=super-init-not-called
"""
:param pred: model predictions
"""
self._use_scala_udf = use_scala_udf
self.pred = convert2spark(pred)

@staticmethod
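The same pattern applies to the `RecOnlyMetric` subclasses touched here: the constructor grows a `use_scala_udf` keyword while the rest of the interface is unchanged. A sketch, assuming `RecOnlyMetric` instances are called with just `(recommendations, k)` and the default column names (toy data only):

```python
import pandas as pd
from replay.metrics import Unexpectedness

# Baseline model predictions and the evaluated recommendations (illustrative).
baseline_pred = pd.DataFrame(
    {"user_idx": [0, 0], "item_idx": [1, 2], "relevance": [0.9, 0.8]}
)
recs = pd.DataFrame(
    {"user_idx": [0, 0], "item_idx": [2, 3], "relevance": [0.7, 0.6]}
)

metric = Unexpectedness(baseline_pred, use_scala_udf=True)
print(metric(recs, k=2))
```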
1 change: 1 addition & 0 deletions replay/session_handler.py
@@ -39,6 +39,7 @@ def get_spark_session(
"spark.driver.extraJavaOptions",
"-Dio.netty.tryReflectionSetAccessible=true",
)
.config("spark.jars", os.environ.get("REPLAY_JAR_PATH", "jars/replay_2.12-0.1.jar"))
.config("spark.sql.shuffle.partitions", str(shuffle_partitions))
.config("spark.local.dir", os.path.join(user_home, "tmp"))
.config("spark.driver.maxResultSize", "4g")
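With this change the jar location can be overridden through the `REPLAY_JAR_PATH` environment variable, falling back to the bundled `jars/replay_2.12-0.1.jar`. A minimal sketch of pointing the session at a custom build, assuming `State()` lazily creates the session via `get_spark_session` when none exists yet (the path is hypothetical):

```python
import os

# spark.jars is a static config, so set the variable before the session is created.
os.environ["REPLAY_JAR_PATH"] = "/opt/jars/replay_2.12-0.1.jar"  # hypothetical location

from replay.session_handler import State

spark = State().session
print(spark.sparkContext.getConf().get("spark.jars"))  # should echo the custom path
```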
9 changes: 9 additions & 0 deletions scala/.gitignore
@@ -0,0 +1,9 @@
.idea
.bloop
.bsp
.metals
.vscode
target
project/target
work
assembly
20 changes: 20 additions & 0 deletions scala/build.sbt
@@ -0,0 +1,20 @@
import sbt.Keys.resolvers

name := "replay"

version := "0.1"

scalaVersion := "2.12.15"

// idePackagePrefix := Some("org.apache.spark.ml.feature.lightautoml")

resolvers ++= Seq(
("Confluent" at "http://packages.confluent.io/maven")
.withAllowInsecureProtocol(true)
)

libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % "3.1.3",
"org.apache.spark" %% "spark-sql" % "3.1.3",
"org.apache.spark" %% "spark-mllib" % "3.1.3",
)
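Assuming the standard sbt layout, running `sbt package` inside `scala/` should produce `target/scala-2.12/replay_2.12-0.1.jar`, i.e. the artifact committed under `jars/`. A quick, hedged way to check that the compiled UDF object is reachable from PySpark (the object path comes from this PR; everything else is an assumption):

```python
from replay.session_handler import State

sc = State().session.sparkContext
udfs = sc._jvm.org.apache.spark.replay.utils.ScalaPySparkUDFs
# Prints a py4j handle if the jar is on the classpath; if it is not, the path
# resolves to a bare JavaPackage and the call below raises a TypeError.
print(udfs.getHitRateMetricValue())
```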
1 change: 1 addition & 0 deletions scala/project/build.properties
@@ -0,0 +1 @@
sbt.version=1.7.1