Add fake streaming source #8

Open · wants to merge 1 commit into master
111 changes: 85 additions & 26 deletions pyspark_datasources/fake.py
@@ -1,5 +1,35 @@
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType, StringType
from typing import List

from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
    DataSourceStreamReader,
    InputPartition,
)
from pyspark.sql.types import StringType, StructType


def _validate_faker_schema(schema):
    # Verify the library is installed correctly.
    try:
        from faker import Faker
    except ImportError:
        raise Exception("You need to install `faker` to use the fake datasource.")

    fake = Faker()
    for field in schema.fields:
        try:
            getattr(fake, field.name)()
        except AttributeError:
            raise Exception(
                f"Unable to find a method called `{field.name}` in faker. "
                f"Please check Faker's documentation to see supported methods."
            )
        if field.dataType != StringType():
            raise Exception(
                f"Field `{field.name}` is not a StringType. "
                f"Only StringType is supported in the fake datasource."
            )


class FakeDataSource(DataSource):
@@ -19,6 +49,7 @@ class FakeDataSource(DataSource):
    - The fake data source relies on the `faker` library. Make sure it is installed and accessible.
    - Only string type fields are supported, and each field name must correspond to a method name in
      the `faker` library.
    - When using the stream reader, `numRows` is the number of rows per microbatch.

    Examples
    --------
@@ -61,6 +92,21 @@ class FakeDataSource(DataSource):
    | Caitlin Reed|1983-06-22| 89813|Pennsylvania|
    | Douglas James|2007-01-18| 46226| Alabama|
    +--------------+----------+-------+------------+

    Streaming fake data:

    >>> stream = spark.readStream.format("fake").load().writeStream.format("console").start()
    Batch: 0
    +--------------+----------+-------+------------+
    | name| date|zipcode| state|
    +--------------+----------+-------+------------+
    | Tommy Diaz|1976-11-17| 27627|South Dakota|
    |Jonathan Perez|1986-02-23| 81307|Rhode Island|
    | Julia Farmer|1990-10-10| 40482| Virginia|
    +--------------+----------+-------+------------+
    Batch: 1
    ...
    >>> stream.stop()
    """

    @classmethod
@@ -70,40 +116,24 @@ def name(cls):
    def schema(self):
        return "name string, date string, zipcode string, state string"

    def reader(self, schema: StructType):
        # Verify the library is installed correctly.
        try:
            from faker import Faker
        except ImportError:
            raise Exception("You need to install `faker` to use the fake datasource.")

        # Check that the schema is valid before proceeding to read.
        fake = Faker()
        for field in schema.fields:
            try:
                getattr(fake, field.name)()
            except AttributeError:
                raise Exception(
                    f"Unable to find a method called `{field.name}` in faker. "
                    f"Please check Faker's documentation to see supported methods."
                )
            if field.dataType != StringType():
                raise Exception(
                    f"Field `{field.name}` is not a StringType. "
                    f"Only StringType is supported in the fake datasource."
                )

    def reader(self, schema: StructType) -> "FakeDataSourceReader":
        _validate_faker_schema(schema)
        return FakeDataSourceReader(schema, self.options)

    def streamReader(self, schema) -> "FakeDataSourceStreamReader":
        _validate_faker_schema(schema)
        return FakeDataSourceStreamReader(schema, self.options)


class FakeDataSourceReader(DataSourceReader):

    def __init__(self, schema, options):
    def __init__(self, schema, options) -> None:
        self.schema: StructType = schema
        self.options = options

    def read(self, partition):
        from faker import Faker

        fake = Faker()
        # Note: every value in this `self.options` dictionary is a string.
        num_rows = int(self.options.get("numRows", 3))
@@ -113,3 +143,32 @@ def read(self, partition):
                value = getattr(fake, field.name)()
                row.append(value)
            yield tuple(row)


class FakeDataSourceStreamReader(DataSourceStreamReader):
    def __init__(self, schema, options) -> None:
        self.schema: StructType = schema
        self.rows_per_microbatch = int(options.get("numRows", 3))
        self.options = options
        self.offset = 0

    def initialOffset(self) -> dict:
        return {"offset": 0}

    def latestOffset(self) -> dict:
        self.offset += self.rows_per_microbatch
        return {"offset": self.offset}

    def partitions(self, start, end) -> List[InputPartition]:
        return [InputPartition(end["offset"] - start["offset"])]

    def read(self, partition):
        from faker import Faker

        fake = Faker()
        for _ in range(partition.value):
            row = []
            for field in self.schema.fields:
                value = getattr(fake, field.name)()
                row.append(value)
            yield tuple(row)
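
For orientation, the microbatch contract implemented above works as follows: Spark calls `initialOffset` once, then `latestOffset` before each batch (which advances the internal counter by `numRows`), turns the offset delta into a single `InputPartition` via `partitions`, and finally calls `read` to generate that many fake rows. Below is a minimal hand-driven sketch of that flow, for illustration only; in a real job Spark's microbatch engine makes these calls, `faker` must be installed, and the import path simply follows the file shown in this diff.

# Illustrative walkthrough of the offset bookkeeping (not part of the PR).
from pyspark.sql.types import StringType, StructField, StructType
from pyspark_datasources.fake import FakeDataSourceStreamReader

schema = StructType([StructField("name", StringType()), StructField("state", StringType())])
reader = FakeDataSourceStreamReader(schema, {"numRows": "2"})

start = reader.initialOffset()                 # {"offset": 0}
end = reader.latestOffset()                    # {"offset": 2}, advanced by numRows
(partition,) = reader.partitions(start, end)   # one InputPartition carrying the row count (2)
rows = list(reader.read(partition))            # two (name, state) tuples produced by faker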
15 changes: 15 additions & 0 deletions tests/test_data_sources.py
@@ -17,6 +17,21 @@ def test_github_datasource(spark):
    assert len(prs) > 0


def test_fake_datasource_stream(spark):
    spark.dataSource.register(FakeDataSource)
    (
        spark.readStream.format("fake")
        .load()
        .writeStream.format("memory")
        .queryName("result")
        .trigger(once=True)
        .start()
        .awaitTermination()
    )
    spark.sql("SELECT * FROM result").show()
    assert spark.sql("SELECT * FROM result").count() == 3


def test_fake_datasource(spark):
    spark.dataSource.register(FakeDataSource)
    df = spark.read.format("fake").load()
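
The stream test above runs with `trigger(once=True)`, so exactly one microbatch of the default three rows lands in the `result` memory table, which is why the count assertion is 3. Beyond the test, here is a usage sketch of the registered streaming source; it assumes Spark with the Python data source API (4.0+), an available SparkSession, and `faker` installed, with the import path taken from the file in this diff and the trigger/timeout values purely illustrative.

# Hypothetical console-sink run of the fake streaming source (not part of the PR).
from pyspark.sql import SparkSession

from pyspark_datasources.fake import FakeDataSource

spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(FakeDataSource)

query = (
    spark.readStream.format("fake")
    .option("numRows", "5")                # five fake rows per microbatch
    .load()
    .writeStream.format("console")
    .trigger(processingTime="10 seconds")
    .start()
)
query.awaitTermination(30)                 # let a few batches print
query.stop()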