
[SPARK-43146][CONNECT][PYTHON] Implement eager evaluation for __repr__ and _repr_html_

### What changes were proposed in this pull request?

Implements eager evaluation for `DataFrame.__repr__` and `DataFrame._repr_html_`.

### Why are the changes needed?

When `spark.sql.repl.eagerEval.enabled` is `True`, DataFrames should be eagerly evaluated and show their results:

```py
>>> spark.conf.set('spark.sql.repl.eagerEval.enabled', True)
>>> spark.range(3)
+---+
| id|
+---+
|  0|
|  1|
|  2|
+---+
```

### Does this PR introduce _any_ user-facing change?

Yes. When `spark.sql.repl.eagerEval.enabled` is set, eager evaluation of `__repr__` and `_repr_html_` becomes available in Spark Connect.

### How was this patch tested?

Enabled the related tests: `test_repr_behaviors` in `test_parity_dataframe`, and removed `_repr_html_` from the unsupported-functions list in `test_connect_basic`.

Closes apache#40800 from ueshin/issues/SPARK-43146/eager_repr.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
ueshin authored and HyukjinKwon committed Apr 19, 2023
1 parent 1c6202c commit 780aeec
Showing 10 changed files with 312 additions and 158 deletions.
15 changes: 15 additions & 0 deletions connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -67,6 +67,7 @@ message Relation {
CoGroupMap co_group_map = 32;
WithWatermark with_watermark = 33;
ApplyInPandasWithState apply_in_pandas_with_state = 34;
HtmlString html_string = 35;

// NA functions
NAFill fill_na = 90;
@@ -457,6 +458,20 @@ message ShowString {
bool vertical = 4;
}

// Compose the string representing rows for output.
// It will invoke 'Dataset.htmlString' to compute the results.
message HtmlString {
// (Required) The input relation.
Relation input = 1;

// (Required) Number of rows to show.
int32 num_rows = 2;

// (Required) If set to more than 0, truncates strings to
// `truncate` characters and all cells will be aligned right.
int32 truncate = 3;
}

// Computes specified statistics for numeric and string columns.
// It will invoke 'Dataset.summary' (same as 'StatFunctions.summary')
// to compute the results.
10 changes: 10 additions & 0 deletions connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -81,6 +81,7 @@ class SparkConnectPlanner(val session: SparkSession) {
val plan = rel.getRelTypeCase match {
// DataFrame API
case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
case proto.Relation.RelTypeCase.HTML_STRING => transformHtmlString(rel.getHtmlString)
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject)
case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
@@ -225,6 +226,15 @@
data = Tuple1.apply(showString) :: Nil)
}

private def transformHtmlString(rel: proto.HtmlString): LogicalPlan = {
val htmlString = Dataset
.ofRows(session, transformRelation(rel.getInput))
.htmlString(rel.getNumRows, rel.getTruncate)
LocalRelation.fromProduct(
output = AttributeReference("html_string", StringType, false)() :: Nil,
data = Tuple1.apply(htmlString) :: Nil)
}

private def transformSql(sql: proto.SQL): LogicalPlan = {
val args = sql.getArgsMap
val parser = session.sessionState.sqlParser
50 changes: 47 additions & 3 deletions python/pyspark/sql/connect/dataframe.py
@@ -96,10 +96,57 @@ def __init__(
self._schema = schema
self._plan: Optional[plan.LogicalPlan] = None
self._session: "SparkSession" = session
# Check whether _repr_html_ is supported or not; we use it to avoid calling RPC twice
# from __repr__ and _repr_html_ while eager evaluation is enabled.
self._support_repr_html = False

def __repr__(self) -> str:
if not self._support_repr_html:
(
repl_eager_eval_enabled,
repl_eager_eval_max_num_rows,
repl_eager_eval_truncate,
) = self._session._get_configs(
"spark.sql.repl.eagerEval.enabled",
"spark.sql.repl.eagerEval.maxNumRows",
"spark.sql.repl.eagerEval.truncate",
)
if repl_eager_eval_enabled == "true":
return self._show_string(
n=int(cast(str, repl_eager_eval_max_num_rows)),
truncate=int(cast(str, repl_eager_eval_truncate)),
vertical=False,
)
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

def _repr_html_(self) -> Optional[str]:
if not self._support_repr_html:
self._support_repr_html = True
(
repl_eager_eval_enabled,
repl_eager_eval_max_num_rows,
repl_eager_eval_truncate,
) = self._session._get_configs(
"spark.sql.repl.eagerEval.enabled",
"spark.sql.repl.eagerEval.maxNumRows",
"spark.sql.repl.eagerEval.truncate",
)
if repl_eager_eval_enabled == "true":
pdf = DataFrame.withPlan(
plan.HtmlString(
child=self._plan,
num_rows=int(cast(str, repl_eager_eval_max_num_rows)),
truncate=int(cast(str, repl_eager_eval_truncate)),
),
session=self._session,
).toPandas()
assert pdf is not None
return pdf["html_string"][0]
else:
return None

_repr_html_.__doc__ = PySparkDataFrame._repr_html_.__doc__

@property
def write(self) -> "DataFrameWriter":
assert self._plan is not None
@@ -1827,9 +1874,6 @@ def writeStream(self) -> DataStreamWriter:
def toJSON(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("toJSON() is not implemented.")

def _repr_html_(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("_repr_html_() is not implemented.")

def sameSemantics(self, other: "DataFrame") -> bool:
assert self._plan is not None
assert other._plan is not None
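A short usage sketch of the flow implemented above (assuming an active Spark Connect session bound to `spark`): in IPython, `_repr_html_` is probed first and flips `_support_repr_html`, so the subsequent `__repr__` call skips a second eager-evaluation RPC.

```py
spark.conf.set("spark.sql.repl.eagerEval.enabled", True)

df = spark.range(3)
html = df._repr_html_()  # runs the HtmlString plan once and sets _support_repr_html
text = repr(df)          # now just the cheap schema string: "DataFrame[id: bigint]"
```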
15 changes: 15 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -395,6 +395,21 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan


class HtmlString(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], num_rows: int, truncate: int) -> None:
super().__init__(child)
self.num_rows = num_rows
self.truncate = truncate

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.html_string.input.CopyFrom(self._child.plan(session))
plan.html_string.num_rows = self.num_rows
plan.html_string.truncate = self.truncate
return plan


class Project(LogicalPlan):
"""Logical plan object for a projection.
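The new plan node can be exercised in isolation as below (a sketch; `child_plan` and `client` are hypothetical stand-ins for a real child `LogicalPlan` and `SparkConnectClient`):

```py
node = HtmlString(child=child_plan, num_rows=20, truncate=20)
rel = node.plan(client)
assert rel.html_string.num_rows == 20 and rel.html_string.truncate == 20
```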
262 changes: 138 additions & 124 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

47 changes: 47 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -96,6 +96,7 @@ class Relation(google.protobuf.message.Message):
CO_GROUP_MAP_FIELD_NUMBER: builtins.int
WITH_WATERMARK_FIELD_NUMBER: builtins.int
APPLY_IN_PANDAS_WITH_STATE_FIELD_NUMBER: builtins.int
HTML_STRING_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -179,6 +180,8 @@
@property
def apply_in_pandas_with_state(self) -> global___ApplyInPandasWithState: ...
@property
def html_string(self) -> global___HtmlString: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -249,6 +252,7 @@
co_group_map: global___CoGroupMap | None = ...,
with_watermark: global___WithWatermark | None = ...,
apply_in_pandas_with_state: global___ApplyInPandasWithState | None = ...,
html_string: global___HtmlString | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -307,6 +311,8 @@
b"group_map",
"hint",
b"hint",
"html_string",
b"html_string",
"join",
b"join",
"limit",
@@ -410,6 +416,8 @@
b"group_map",
"hint",
b"hint",
"html_string",
b"html_string",
"join",
b"join",
"limit",
@@ -506,6 +514,7 @@
"co_group_map",
"with_watermark",
"apply_in_pandas_with_state",
"html_string",
"fill_na",
"drop_na",
"replace",
@@ -1813,6 +1822,44 @@

global___ShowString = ShowString

class HtmlString(google.protobuf.message.Message):
"""Compose the string representing rows for output.
It will invoke 'Dataset.htmlString' to compute the results.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

INPUT_FIELD_NUMBER: builtins.int
NUM_ROWS_FIELD_NUMBER: builtins.int
TRUNCATE_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) The input relation."""
num_rows: builtins.int
"""(Required) Number of rows to show."""
truncate: builtins.int
"""(Required) If set to more than 0, truncates strings to
`truncate` characters and all cells will be aligned right.
"""
def __init__(
self,
*,
input: global___Relation | None = ...,
num_rows: builtins.int = ...,
truncate: builtins.int = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["input", b"input"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"input", b"input", "num_rows", b"num_rows", "truncate", b"truncate"
],
) -> None: ...

global___HtmlString = HtmlString

class StatSummary(google.protobuf.message.Message):
"""Computes specified statistics for numeric and string columns.
It will invoke 'Dataset.summary' (same as 'StatFunctions.summary')
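The generated bindings can also be constructed directly, matching the stub above (a minimal sketch using only the fields this commit adds):

```py
from pyspark.sql.connect.proto import relations_pb2 as pb2

msg = pb2.HtmlString(num_rows=20, truncate=20)
msg.input.CopyFrom(pb2.Relation())  # the (Required) input relation, left empty here
assert msg.HasField("input")
```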
27 changes: 2 additions & 25 deletions python/pyspark/sql/dataframe.py
@@ -22,7 +22,6 @@
import warnings
from collections.abc import Iterable
from functools import reduce
from html import escape as html_escape
from typing import (
Any,
Callable,
@@ -936,32 +935,10 @@ def _repr_html_(self) -> Optional[str]:
if not self._support_repr_html:
self._support_repr_html = True
if self.sparkSession._jconf.isReplEagerEvalEnabled():
max_num_rows = max(self.sparkSession._jconf.replEagerEvalMaxNumRows(), 0)
sock_info = self._jdf.getRowsToPython(
max_num_rows,
return self._jdf.htmlString(
self.sparkSession._jconf.replEagerEvalMaxNumRows(),
self.sparkSession._jconf.replEagerEvalTruncate(),
)
rows = list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
head = rows[0]
row_data = rows[1:]
has_more_data = len(row_data) > max_num_rows
row_data = row_data[:max_num_rows]

html = "<table border='1'>\n"
# generate table head
html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: html_escape(x), head))
# generate table rows
for row in row_data:
html += "<tr><td>%s</td></tr>\n" % "</td><td>".join(
map(lambda x: html_escape(x), row)
)
html += "</table>\n"
if has_more_data:
html += "only showing top %d %s\n" % (
max_num_rows,
"row" if max_num_rows == 1 else "rows",
)
return html
else:
return None

1 change: 0 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2923,7 +2923,6 @@ def test_unsupported_functions(self):
"foreachPartition",
"checkpoint",
"localCheckpoint",
"_repr_html_",
):
with self.assertRaises(NotImplementedError):
getattr(df, f)()
5 changes: 0 additions & 5 deletions python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -50,11 +50,6 @@ def test_pandas_api(self):
def test_repartitionByRange_dataframe(self):
super().test_repartitionByRange_dataframe()

# TODO(SPARK-41834): Implement SparkSession.conf
@unittest.skip("Fails in Spark Connect, should enable.")
def test_repr_behaviors(self):
super().test_repr_behaviors()

@unittest.skip("Spark Connect does not support SparkContext but the tests depend on it.")
def test_same_semantics_error(self):
super().test_same_semantics_error()
38 changes: 38 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -26,6 +26,7 @@ import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import org.apache.commons.lang3.StringUtils
import org.apache.commons.text.StringEscapeUtils

import org.apache.spark.TaskContext
import org.apache.spark.annotation.{DeveloperApi, Stable, Unstable}
@@ -402,6 +403,43 @@ class Dataset[T] private[sql](
sb.toString()
}

/**
* Compose the HTML representing rows for output.
*
* @param _numRows Number of rows to show
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
*/
private[sql] def htmlString(
_numRows: Int,
truncate: Int = 20): String = {
val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1)
// Get rows represented by Seq[Seq[String]], we may get one more line if it has more data.
val tmpRows = getRows(numRows, truncate)

val hasMoreData = tmpRows.length - 1 > numRows
val rows = tmpRows.take(numRows + 1)

val sb = new StringBuilder

sb.append("<table border='1'>\n")

sb.append(rows.head.map(StringEscapeUtils.escapeHtml4)
.mkString("<tr><th>", "</th><th>", "</th></tr>\n"))
rows.tail.foreach { row =>
sb.append(row.map(StringEscapeUtils.escapeHtml4)
.mkString("<tr><td>", "</td><td>", "</td></tr>\n"))
}

sb.append("</table>\n")

if (hasMoreData) {
sb.append(s"only showing top $numRows ${if (numRows == 1) "row" else "rows"}\n")
}

sb.toString()
}

override def toString: String = {
try {
val builder = new StringBuilder
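For reference, a hedged expectation of the markup `htmlString` emits for the `spark.range(3)` example in the description (all three rows fit within the row limit, so no trailing "only showing top N rows" line is appended):

```py
expected = (
    "<table border='1'>\n"
    "<tr><th>id</th></tr>\n"
    "<tr><td>0</td></tr>\n"
    "<tr><td>1</td></tr>\n"
    "<tr><td>2</td></tr>\n"
    "</table>\n"
)
```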
