diff --git a/pyspark_datasources/fake.py b/pyspark_datasources/fake.py
index fd83900..2b9447c 100644
--- a/pyspark_datasources/fake.py
+++ b/pyspark_datasources/fake.py
@@ -1,5 +1,35 @@
-from pyspark.sql.datasource import DataSource, DataSourceReader
-from pyspark.sql.types import StructType, StringType
+from typing import List
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceReader,
+    DataSourceStreamReader,
+    InputPartition,
+)
+from pyspark.sql.types import StringType, StructType
+
+
+def _validate_faker_schema(schema):
+    # Verify the library is installed correctly.
+    try:
+        from faker import Faker
+    except ImportError:
+        raise Exception("You need to install `faker` to use the fake datasource.")
+
+    fake = Faker()
+    for field in schema.fields:
+        try:
+            getattr(fake, field.name)()
+        except AttributeError:
+            raise Exception(
+                f"Unable to find a method called `{field.name}` in faker. "
+                f"Please check Faker's documentation to see supported methods."
+            )
+        if field.dataType != StringType():
+            raise Exception(
+                f"Field `{field.name}` is not a StringType. "
+                f"Only StringType is supported in the fake datasource."
+            )
 
 
 class FakeDataSource(DataSource):
@@ -19,6 +49,7 @@ class FakeDataSource(DataSource):
     - The fake data source relies on the `faker` library. Make sure it is installed and accessible.
     - Only string type fields are supported, and each field name must correspond to a method name in
       the `faker` library.
+    - When using the stream reader, `numRows` is the number of rows per microbatch.
 
     Examples
     --------
@@ -61,6 +92,21 @@ class FakeDataSource(DataSource):
     |  Caitlin Reed|1983-06-22|  89813|Pennsylvania|
     | Douglas James|2007-01-18|  46226|     Alabama|
     +--------------+----------+-------+------------+
+
+    Streaming fake data:
+
+    >>> stream = spark.readStream.format("fake").load().writeStream.format("console").start()
+    Batch: 0
+    +--------------+----------+-------+------------+
+    |          name|      date|zipcode|       state|
+    +--------------+----------+-------+------------+
+    |    Tommy Diaz|1976-11-17|  27627|South Dakota|
+    |Jonathan Perez|1986-02-23|  81307|Rhode Island|
+    |  Julia Farmer|1990-10-10|  40482|    Virginia|
+    +--------------+----------+-------+------------+
+    Batch: 1
+    ...
+    >>> stream.stop()
     """
 
     @classmethod
@@ -70,40 +116,24 @@ def name(cls):
     def schema(self):
         return "name string, date string, zipcode string, state string"
 
-    def reader(self, schema: StructType):
-        # Verify the library is installed correctly.
-        try:
-            from faker import Faker
-        except ImportError:
-            raise Exception("You need to install `faker` to use the fake datasource.")
-
-        # Check the schema is valid before proceed to reading.
-        fake = Faker()
-        for field in schema.fields:
-            try:
-                getattr(fake, field.name)()
-            except AttributeError:
-                raise Exception(
-                    f"Unable to find a method called `{field.name}` in faker. "
-                    f"Please check Faker's documentation to see supported methods."
-                )
-            if field.dataType != StringType():
-                raise Exception(
-                    f"Field `{field.name}` is not a StringType. "
-                    f"Only StringType is supported in the fake datasource."
-                )
-
+    def reader(self, schema: StructType) -> "FakeDataSourceReader":
+        _validate_faker_schema(schema)
         return FakeDataSourceReader(schema, self.options)
 
+    def streamReader(self, schema) -> "FakeDataSourceStreamReader":
+        _validate_faker_schema(schema)
+        return FakeDataSourceStreamReader(schema, self.options)
+
 
 class FakeDataSourceReader(DataSourceReader):
 
-    def __init__(self, schema, options):
+    def __init__(self, schema, options) -> None:
         self.schema: StructType = schema
         self.options = options
 
     def read(self, partition):
         from faker import Faker
+
         fake = Faker()
         # Note: every value in this `self.options` dictionary is a string.
         num_rows = int(self.options.get("numRows", 3))
@@ -113,3 +143,32 @@ def read(self, partition):
                 value = getattr(fake, field.name)()
                 row.append(value)
             yield tuple(row)
+
+
+class FakeDataSourceStreamReader(DataSourceStreamReader):
+    def __init__(self, schema, options) -> None:
+        self.schema: StructType = schema
+        self.rows_per_microbatch = int(options.get("numRows", 3))
+        self.options = options
+        self.offset = 0
+
+    def initialOffset(self) -> dict:
+        return {"offset": 0}
+
+    def latestOffset(self) -> dict:
+        self.offset += self.rows_per_microbatch
+        return {"offset": self.offset}
+
+    def partitions(self, start, end) -> List[InputPartition]:
+        return [InputPartition(end["offset"] - start["offset"])]
+
+    def read(self, partition):
+        from faker import Faker
+
+        fake = Faker()
+        for _ in range(partition.value):
+            row = []
+            for field in self.schema.fields:
+                value = getattr(fake, field.name)()
+                row.append(value)
+            yield tuple(row)
diff --git a/tests/test_data_sources.py b/tests/test_data_sources.py
index 9ca74c1..1d337e2 100644
--- a/tests/test_data_sources.py
+++ b/tests/test_data_sources.py
@@ -17,6 +17,21 @@ def test_github_datasource(spark):
     assert len(prs) > 0
 
 
+def test_fake_datasource_stream(spark):
+    spark.dataSource.register(FakeDataSource)
+    (
+        spark.readStream.format("fake")
+        .load()
+        .writeStream.format("memory")
+        .queryName("result")
+        .trigger(once=True)
+        .start()
+        .awaitTermination()
+    )
+    spark.sql("SELECT * FROM result").show()
+    assert spark.sql("SELECT * FROM result").count() == 3
+
+
 def test_fake_datasource(spark):
     spark.dataSource.register(FakeDataSource)
     df = spark.read.format("fake").load()
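Usage sketch (not part of the patch): assuming the package is importable as `pyspark_datasources` and a SparkSession that supports the Python data source API is available, the new stream reader can be driven with a custom `numRows` option, which sets how many rows each microbatch contains.

    from pyspark.sql import SparkSession

    from pyspark_datasources.fake import FakeDataSource

    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(FakeDataSource)

    # numRows is read by FakeDataSourceStreamReader; each microbatch yields that many rows (default 3).
    query = (
        spark.readStream.format("fake")
        .option("numRows", "5")
        .load()
        .writeStream.format("console")
        .trigger(once=True)
        .start()
    )
    query.awaitTermination()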