@@ -18,21 +18,19 @@
 import time
 import unittest
 
-from pyspark.sql import Row
-from pyspark.sql.functions import col, lit, count, sum, mean
+from pyspark.sql import Row, Observation, functions as F
 from pyspark.errors import (
     PySparkAssertionError,
     PySparkTypeError,
     PySparkValueError,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual
 
 
 class DataFrameObservationTestsMixin:
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
-        from pyspark.sql import Observation
-
         df = self.spark.createDataFrame(
             [
                 (1, 1.0, "one"),
@@ -58,11 +56,11 @@ def test_observe(self):
             df.orderBy("id")
             .observe(
                 named_observation,
-                count(lit(1)).alias("cnt"),
-                sum(col("id")).alias("sum"),
-                mean(col("val")).alias("mean"),
+                F.count(F.lit(1)).alias("cnt"),
+                F.sum(F.col("id")).alias("sum"),
+                F.mean(F.col("val")).alias("mean"),
             )
-            .observe(unnamed_observation, count(lit(1)).alias("rows"))
+            .observe(unnamed_observation, F.count(F.lit(1)).alias("rows"))
         )
 
         # test that observe works transparently
@@ -81,7 +79,7 @@ def test_observe(self):
         self.assertEqual(unnamed_observation.get, dict(rows=3))
 
         with self.assertRaises(PySparkAssertionError) as pe:
-            df.observe(named_observation, count(lit(1)).alias("count"))
+            df.observe(named_observation, F.count(F.lit(1)).alias("count"))
 
         self.check_error(
             exception=pe.exception,
@@ -106,7 +104,7 @@ def test_observe(self):
         )
 
         # dataframe.observe requires non-None Columns
-        for args in [(None,), ("id",), (lit(1), None), (lit(1), "id")]:
+        for args in [(None,), ("id",), (F.lit(1), None), (F.lit(1), "id")]:
             with self.subTest(args=args):
                 with self.assertRaises(PySparkTypeError) as pe:
                     df.observe(Observation(), *args)
@@ -140,7 +138,9 @@ def onQueryTerminated(self, event):
         self.spark.streams.addListener(TestListener())
 
         df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-        df = df.observe("metric", count(lit(1)).alias("cnt"), sum(col("value")).alias("sum"))
+        df = df.observe(
+            "metric", F.count(F.lit(1)).alias("cnt"), F.sum(F.col("value")).alias("sum")
+        )
         q = df.writeStream.format("noop").queryName("test").start()
         self.assertTrue(q.isActive)
         time.sleep(10)
@@ -157,15 +157,13 @@ def onQueryTerminated(self, event):
 
     def test_observe_with_same_name_on_different_dataframe(self):
         # SPARK-45656: named observations with the same name on different datasets
-        from pyspark.sql import Observation
-
         observation1 = Observation("named")
         df1 = self.spark.range(50)
-        observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt"))
+        observed_df1 = df1.observe(observation1, F.count(F.lit(1)).alias("cnt"))
 
         observation2 = Observation("named")
         df2 = self.spark.range(100)
-        observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt"))
+        observed_df2 = df2.observe(observation2, F.count(F.lit(1)).alias("cnt"))
 
         observed_df1.collect()
         observed_df2.collect()
@@ -174,8 +172,6 @@ def test_observe_with_same_name_on_different_dataframe(self):
         self.assertEqual(observation2.get, dict(cnt=100))
 
     def test_observe_on_commands(self):
-        from pyspark.sql import Observation
-
         df = self.spark.range(50)
 
         test_table = "test_table"
@@ -190,10 +186,46 @@ def test_observe_on_commands(self):
         ]:
             with self.subTest(command=command):
                 observation = Observation()
-                observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
+                observed_df = df.observe(observation, F.count(F.lit(1)).alias("cnt"))
                 action(observed_df)
                 self.assertEqual(observation.get, dict(cnt=50))
 
+    def test_observe_with_struct_type(self):
+        observation = Observation("struct")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"struct": Row(rows=10, maxid=9)})
+
+    def test_observe_with_array_type(self):
+        observation = Observation("array")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.array(F.count(F.lit(1))).alias("array"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"array": [10]})
+
+    def test_observe_with_map_type(self):
+        observation = Observation("map")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.create_map(F.lit("count"), F.count(F.lit(1))).alias("map"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"map": {"count": 10}})
+
 
 class DataFrameObservationTests(
     DataFrameObservationTestsMixin,
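For readers unfamiliar with the API under test, here is a minimal, self-contained sketch of the pattern these tests exercise: attach a named `Observation` to a DataFrame, run an action, then read the collected metrics from `Observation.get`. It assumes a local `SparkSession` and a Spark build that includes this change (struct- and map-typed metrics are exactly what the new tests cover); the combined two-metric example is illustrative, not taken from the diff.

```python
from pyspark.sql import SparkSession, Observation, functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

# One Observation can carry several named metric columns.
observation = Observation("metrics")
df = spark.range(10).observe(
    observation,
    F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"),
    F.create_map(F.lit("count"), F.count(F.lit(1))).alias("map"),  # map-typed metric
)

# Metrics are only populated once an action has run over the observed plan.
df.collect()

# Expected: {'struct': Row(rows=10, maxid=9), 'map': {'count': 10}}
print(observation.get)
```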