
Commit 674ed48

[SPARK-54318][PYTHON][DOCS] Fix doctests in pyspark.sql.dataframe
### What changes were proposed in this pull request?

Fix doctests in `pyspark.sql.dataframe`.

### Why are the changes needed?

To refine docstrings and improve test coverage.

### Does this PR introduce _any_ user-facing change?

Yes, doc-only changes.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53013 from zhengruifeng/doctest_df.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent d218805 commit 674ed48
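For anyone reproducing the CI run locally, these doctests are normally exercised through Spark's dev tooling (something like `python/run-tests --testnames pyspark.sql.dataframe` from the repository root), or directly with the standard library's doctest machinery. Below is a minimal sketch of the latter, roughly mirroring what pyspark's own `_test()` helpers do; the local master setting and app name are illustrative assumptions, not part of this commit:

import doctest

from pyspark.sql import SparkSession
import pyspark.sql.dataframe as mod

# The examples reference a `spark` global, so inject a session into the doctest globals.
spark = SparkSession.builder.master("local[4]").appName("dataframe-doctests").getOrCreate()
globs = mod.__dict__.copy()
globs["spark"] = spark

results = doctest.testmod(
    mod,
    globs=globs,
    optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
)
spark.stop()
print(results)  # e.g. TestResults(failed=0, attempted=...)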

File tree

1 file changed: +31 additions, -28 deletions


python/pyspark/sql/dataframe.py

Lines changed: 31 additions & 28 deletions
@@ -2099,13 +2099,13 @@ def sample(
 
         Examples
         --------
-        >>> df = spark.range(10)
+        >>> df = spark.range(0, 10, 1, 1)
         >>> df.sample(0.5, 3).count() # doctest: +SKIP
         7
-        >>> df.sample(fraction=0.5, seed=3).count() # doctest: +SKIP
-        7
-        >>> df.sample(withReplacement=True, fraction=0.5, seed=3).count() # doctest: +SKIP
-        1
+        >>> df.sample(fraction=0.5, seed=3).count()
+        4
+        >>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
+        2
         >>> df.sample(1.0).count()
         10
         >>> df.sample(fraction=1.0).count()
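The switch from `spark.range(10)` to `spark.range(0, 10, 1, 1)` pins the example DataFrame to a single partition, which is presumably what makes the seeded `sample` counts reproducible and lets the `# doctest: +SKIP` directives be dropped. A rough sketch of the idea; the printed counts are the ones recorded in the new doctest, not a guarantee for every environment:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# range(start, end, step, numPartitions): one partition gives the seeded
# sampler a fixed traversal order, so repeated runs draw the same rows.
df = spark.range(0, 10, 1, 1)
print(df.sample(fraction=0.5, seed=3).count())                        # 4 in the updated doctest
print(df.sample(withReplacement=True, fraction=0.5, seed=3).count())  # 2 in the updated doctest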
@@ -2187,8 +2187,8 @@ def sampleBy(
 
         Examples
         --------
-        >>> from pyspark.sql.functions import col
-        >>> dataset = spark.range(0, 100, 1, 5).select((col("id") % 3).alias("key"))
+        >>> from pyspark.sql import functions as sf
+        >>> dataset = spark.range(0, 100, 1, 5).select((sf.col("id") % 3).alias("key"))
         >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
         >>> sampled.groupBy("key").count().orderBy("key").show()
         +---+-----+
@@ -2198,7 +2198,7 @@ def sampleBy(
         |  1|    9|
         +---+-----+
 
-        >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
+        >>> dataset.sampleBy(sf.col("key"), fractions={2: 1.0}, seed=0).count()
         33
         """
         ...
@@ -2315,9 +2315,9 @@ def columns(self) -> List[str]:
 
         Example 4: Iterating over columns to apply a transformation
 
-        >>> import pyspark.sql.functions as f
+        >>> import pyspark.sql.functions as sf
         >>> for col_name in df.columns:
-        ...     df = df.withColumn(col_name, f.upper(f.col(col_name)))
+        ...     df = df.withColumn(col_name, sf.upper(col_name))
         >>> df.show()
         +---+-----+-----+
         |age| name|state|
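The simplification from `f.upper(f.col(col_name))` to `sf.upper(col_name)` works because most helpers in `pyspark.sql.functions` accept either a `Column` or a column-name string. A small illustrative sketch (the sample data here is made up for the example):

from pyspark.sql import SparkSession, functions as sf

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("70", "Alice", "NY")], ["age", "name", "state"])

# Passing the name directly is equivalent to wrapping it in sf.col(...)
by_name = df.select(sf.upper("name"))
by_col = df.select(sf.upper(sf.col("name")))
assert by_name.collect() == by_col.collect()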
@@ -2478,14 +2478,16 @@ def alias(self, alias: str) -> "DataFrame":
 
         Examples
         --------
-        >>> from pyspark.sql.functions import col, desc
+        >>> from pyspark.sql import functions as sf
         >>> df = spark.createDataFrame(
         ...     [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
         >>> df_as1 = df.alias("df_as1")
         >>> df_as2 = df.alias("df_as2")
-        >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
+        >>> joined_df = df_as1.join(df_as2,
+        ...     sf.col("df_as1.name") == sf.col("df_as2.name"), 'inner')
         >>> joined_df.select(
-        ...     "df_as1.name", "df_as2.name", "df_as2.age").sort(desc("df_as1.name")).show()
+        ...     "df_as1.name", "df_as2.name", "df_as2.age"
+        ... ).sort(sf.desc("df_as1.name")).show()
         +-----+-----+---+
         | name| name|age|
         +-----+-----+---+
@@ -2610,7 +2612,7 @@ def join(
         they will appear with `NULL` in the `name` column of `df`, and vice versa for `df2`.
 
         >>> joined = df.join(df2, df.name == df2.name, "outer").sort(sf.desc(df.name))
-        >>> joined.show() # doctest: +SKIP
+        >>> joined.show()
         +-----+----+----+------+
         | name| age|name|height|
         +-----+----+----+------+
@@ -2621,7 +2623,7 @@ def join(
 
         To unambiguously select output columns, specify the dataframe along with the column name:
 
-        >>> joined.select(df.name, df2.height).show() # doctest: +SKIP
+        >>> joined.select(df.name, df2.height).show()
         +-----+------+
         | name|height|
         +-----+------+
@@ -4404,11 +4406,11 @@ def observe(
         --------
         When ``observation`` is :class:`Observation`, only batch queries work as below.
 
-        >>> from pyspark.sql.functions import col, count, lit, max
-        >>> from pyspark.sql import Observation
+        >>> from pyspark.sql import Observation, functions as sf
         >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
         >>> observation = Observation("my metrics")
-        >>> observed_df = df.observe(observation, count(lit(1)).alias("count"), max(col("age")))
+        >>> observed_df = df.observe(observation,
+        ...     sf.count(sf.lit(1)).alias("count"), sf.max("age"))
         >>> observed_df.count()
         2
         >>> observation.get
@@ -4441,13 +4443,13 @@ def observe(
         >>> error_listener = MyErrorListener()
         >>> spark.streams.addListener(error_listener)
         >>> sdf = spark.readStream.format("rate").load().withColumn(
-        ...     "error", col("value")
+        ...     "error", sf.col("value")
         ... )
         >>> # Observe row count (rc) and error row count (erc) in the streaming Dataset
         ... observed_ds = sdf.observe(
         ...     "my_event",
-        ...     count(lit(1)).alias("rc"),
-        ...     count(col("error")).alias("erc"))
+        ...     sf.count(sf.lit(1)).alias("rc"),
+        ...     sf.count(sf.col("error")).alias("erc"))
         >>> try:
         ...     q = observed_ds.writeStream.format("console").start()
         ...     time.sleep(5)
@@ -4512,11 +4514,11 @@ def union(self, other: "DataFrame") -> "DataFrame":
 
         Example 2: Combining two DataFrames with different schemas
 
-        >>> from pyspark.sql.functions import lit
+        >>> from pyspark.sql import functions as sf
         >>> df1 = spark.createDataFrame([(100001, 1), (100002, 2)], schema="id LONG, money INT")
         >>> df2 = spark.createDataFrame([(3, 100003), (4, 100003)], schema="money INT, id LONG")
-        >>> df1 = df1.withColumn("age", lit(30))
-        >>> df2 = df2.withColumn("age", lit(40))
+        >>> df1 = df1.withColumn("age", sf.lit(30))
+        >>> df2 = df2.withColumn("age", sf.lit(40))
         >>> df3 = df1.union(df2)
         >>> df3.show()
         +------+------+---+
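As background for Example 2 (where `df2` deliberately declares its schema as `money INT, id LONG`): `union` resolves columns by position, not by name, so differing column orders can silently pair unrelated columns, and `unionByName` is the name-based alternative. A brief sketch of the contrast:

from pyspark.sql import SparkSession, functions as sf

spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame([(100001, 1)], schema="id LONG, money INT").withColumn("age", sf.lit(30))
df2 = spark.createDataFrame([(3, 100003)], schema="money INT, id LONG").withColumn("age", sf.lit(40))

df1.union(df2).show()        # positional: df2.money lands in df1's id column
df1.unionByName(df2).show()  # matched by name: id, money, and age line up correctly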
@@ -6065,10 +6067,10 @@ def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any)
 
         Examples
         --------
-        >>> from pyspark.sql.functions import col
+        >>> from pyspark.sql import functions as sf
         >>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"])
         >>> def cast_all_to_int(input_df):
-        ...     return input_df.select([col(col_name).cast("int") for col_name in input_df.columns])
+        ...     return input_df.select([sf.col(c).cast("int") for c in input_df.columns])
         ...
         >>> def sort_columns_asc(input_df):
         ...     return input_df.select(*sorted(input_df.columns))
@@ -6082,8 +6084,9 @@ def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any)
         +-----+---+
 
         >>> def add_n(input_df, n):
-        ...     return input_df.select([(col(col_name) + n).alias(col_name)
-        ...                             for col_name in input_df.columns])
+        ...     cols = [(sf.col(c) + n).alias(c) for c in input_df.columns]
+        ...     return input_df.select(cols)
+        ...
         >>> df.transform(add_n, 1).transform(add_n, n=10).show()
         +---+-----+
         |int|float|
