-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtest_data_sources.py
40 lines (31 loc) · 965 Bytes
/
test_data_sources.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pytest
from pyspark.sql import SparkSession
from pyspark_datasources import *
@pytest.fixture
def spark():
    """Provide a SparkSession to the requesting test.

    The session comes from ``SparkSession.builder.getOrCreate()``, so an
    already-active session is reused instead of a new one being built.
    No teardown runs after the test: the session is left running.
    """
    session = SparkSession.builder.getOrCreate()
    yield session
def test_github_datasource(spark):
    """Read ``apache/spark`` through the ``github`` format and expect rows.

    Registers ``GithubDataSource`` on the session, loads the path
    ``"apache/spark"``, collects the result to the driver and asserts that
    at least one row came back.
    """
    # NOTE(review): presumably this reaches the live GitHub API, so the test
    # would need network access -- confirm against GithubDataSource.
    spark.dataSource.register(GithubDataSource)

    reader = spark.read.format("github")
    pull_requests_df = reader.load("apache/spark")
    collected_rows = pull_requests_df.collect()

    assert len(collected_rows) > 0
def test_fake_datasource_stream(spark):
    """Stream the ``fake`` source into a memory table and check its size.

    Registers ``FakeDataSource``, runs a single-trigger streaming query that
    writes to the in-memory table ``result``, waits for the query to finish,
    prints the table and asserts that it holds exactly 3 rows.
    """
    spark.dataSource.register(FakeDataSource)

    # Build the streaming pipeline step by step: source -> memory sink.
    streaming_df = spark.readStream.format("fake").load()
    memory_writer = streaming_df.writeStream.format("memory")
    memory_writer = memory_writer.queryName("result")
    memory_writer = memory_writer.trigger(once=True)

    # Block until the one-shot query has completed so the table is populated.
    running_query = memory_writer.start()
    running_query.awaitTermination()

    select_all = "SELECT * FROM result"
    spark.sql(select_all).show()
    assert spark.sql(select_all).count() == 3
def test_fake_datasource(spark):
    """Batch-read the ``fake`` source and check the resulting shape.

    Registers ``FakeDataSource``, loads it with no options, prints the
    DataFrame and asserts that it has exactly 3 rows and 4 columns.
    """
    spark.dataSource.register(FakeDataSource)

    fake_reader = spark.read.format("fake")
    fake_df = fake_reader.load()
    fake_df.show()

    assert fake_df.count() == 3
    assert len(fake_df.columns) == 4