Basic Arrow.jl-based collect and createDataFrame #115

Draft · wants to merge 2 commits into base: main
6 changes: 4 additions & 2 deletions Project.toml
@@ -3,6 +3,8 @@ uuid = "e3819d11-95af-5eea-9727-70c091663a01"
version = "0.6.1"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
JavaCall = "494afd89-becb-516b-aafa-70d2670c0337"
@@ -11,6 +13,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Umlaut = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841"

[compat]
@@ -22,8 +25,7 @@ Umlaut = "0.2"
julia = "1.6"

[extras]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "DataFrames"]
test = ["Test"]
@@ -0,0 +1,153 @@
package org.apache.spark.sql.julia

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter, ArrowStreamReader, ArrowStreamWriter, SeekableReadChannel, WriteChannel}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.commons.io.output.ByteArrayOutputStream
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.execution.{ExplainMode, SQLExecution}
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowWriter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils

import java.io.{File, FileInputStream}
import java.nio.channels.FileChannel
import java.nio.file.{OpenOption, Path, Paths, StandardOpenOption}

object DatasetUtils {
  /** Based on a row iterator and Spark's ArrowWriter */
  def collectToArrow1[R](df: Dataset[R], tempFilePath: String): Unit = {
    // Get rows Iterator
    // Can't use df.collectToIterator() because we need InternalRow to be able to use Spark's ArrowWriter
    val rows = SQLExecution.withNewExecutionId(df.queryExecution, Some("collectToArrow")) {
      df.queryExecution.executedPlan.resetMetrics()
      df.queryExecution.executedPlan.executeToIterator()
    }

    val timeZone = df.sqlContext.conf.sessionLocalTimeZone
    val arrowSchema = ArrowUtils.toArrowSchema(df.schema, timeZone)

    val allocator = ArrowUtils.rootAllocator.newChildAllocator(s"Julia collectToArrow", 0, Long.MaxValue)
    val root = VectorSchemaRoot.create(arrowSchema, allocator)

    try {
      Utils.tryWithResource(FileChannel.open(Paths.get(tempFilePath), StandardOpenOption.WRITE)) { tempFile =>
        val arrowWriter = ArrowWriter.create(root)
        val writer = new ArrowFileWriter(root, null, tempFile)
        writer.start()

        for (row <- rows) {
          arrowWriter.write(row)
        }
        arrowWriter.finish()
        writer.writeBatch()
        arrowWriter.reset()
        writer.end()
      }
    } finally {
      root.close()
      allocator.close()
    }
  }

  private def writeArrowSchema(schema: StructType, timeZone: String, out: WriteChannel): Unit = {
    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZone)
    MessageSerializer.serialize(out, arrowSchema)
  }

  private def iterateRdd[T: scala.reflect.ClassTag](rdd: RDD[T], preserveOrder: Boolean, f: T => Unit): Unit = {
    if (preserveOrder) {
      // toLocalIterator has the disadvantage of running a job for each partition, one after the other,
      // so it might be much slower for small datasets with many partitions.
      // An easy fix for the user is to call coalesce(1) before collectToArrow.
      for (x <- rdd.toLocalIterator) {
        f(x)
      }
    } else {
      // A nice way to process partitions as they arrive at the driver, borrowed from how PySpark makes Arrow batches.
      rdd.sparkContext.runJob(
        rdd,
        (it: Iterator[T]) => it.toArray,
        (_, xs: Array[T]) => {
          for (x <- xs) {
            f(x)
          }
        })
    }
  }

  /** Based on Spark's toArrowBatchRdd */
  def collectToArrow2[R](df: Dataset[R], outputFilePath: String, preserveOrder: Boolean): Unit = {
    val batchRdd = df.toArrowBatchRdd

    try {
      Utils.tryWithResource(FileChannel.open(Paths.get(outputFilePath), StandardOpenOption.WRITE)) { outputFile =>
        val writeChannel = new WriteChannel(outputFile)
        writeArrowSchema(df.schema, df.sqlContext.conf.sessionLocalTimeZone, writeChannel)

        iterateRdd(batchRdd, preserveOrder, (batch: Array[Byte]) => writeChannel.write(batch))
      }
    } catch {
      case x: Throwable =>
        x.printStackTrace()
        throw x
    }
  }

  private def readArrowSchema(inputFilePath: String): Schema = {
    Utils.tryWithResource(FileChannel.open(Paths.get(inputFilePath))) { file =>
      readArrowSchema(file)
    }
  }

  private def readArrowSchema(input: java.nio.channels.ReadableByteChannel): Schema = {
    val allocator = ArrowUtils.rootAllocator.newChildAllocator("Julia readArrowSchema", 0, Long.MaxValue)
    val reader = new ArrowStreamReader(input, allocator)
    try {
      reader.getVectorSchemaRoot.getSchema
    } finally {
      reader.close()
      allocator.close()
    }
  }

  /** Based on ArrowConverters.toDataFrame */
  def fromArrow1(sess: SparkSession, inputFilePath: String): DataFrame = {
    val schema = ArrowUtils.fromArrowSchema(readArrowSchema(inputFilePath))
    val batches = ArrowConverters.readArrowStreamFromFile(sess.sqlContext, inputFilePath)
    ArrowConverters.toDataFrame(batches, schema.json, sess.sqlContext)
  }

  /** Based on ArrowConverters.fromBatchIterator. Creates a LocalRelation, which is nice for smaller tables -
    * it should enable filter pushdown on JOINs with this relation. */
  def fromArrow2(sess: SparkSession, inputFilePath: String): DataFrame = {
    val timeZone = sess.sessionState.conf.sessionLocalTimeZone

    val taskContext = TaskContext.empty()
    try {
      val arrowSchema = readArrowSchema(inputFilePath)
      val schema = ArrowUtils.fromArrowSchema(arrowSchema)
      val rows = Utils.tryWithResource(FileChannel.open(Paths.get(inputFilePath))) { tempFile =>
        ArrowConverters.fromBatchIterator(ArrowConverters.getBatchesFromStream(tempFile), schema, timeZone, taskContext).map(cloneRow(schema)).toArray
      }
      val relation = LocalRelation(schema.toAttributes, rows, isStreaming = false)
      Dataset.ofRows(sess, relation)
    } finally {
      taskContext.markTaskCompleted(None)
    }
  }

  private def cloneRow(schema: StructType): InternalRow => InternalRow = {
    // ColumnarBatchRow.copy is buggy - it cannot handle nested objects, arrays and so on
    val projection = UnsafeProjection.create(schema)
    (r: InternalRow) => projection.apply(r).copy()
  }
}
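The Julia side reaches these helpers through JavaCall. A minimal sketch of the bridging calls, assuming `ds` is a Spark.jl `DataFrame`, `path` is a temporary file, and the `JDatasetUtils`, `JDataset`, `JSparkSession`, and `JString` constants from `src/defs.jl` are in scope:

```
import Arrow
using JavaCall: jcall, jboolean

# Dump the Dataset to an Arrow stream file on the driver; the last argument is preserveOrder.
jcall(JDatasetUtils, "collectToArrow2", Nothing, (JDataset, JString, jboolean), ds.jdf, path, false)
tbl = Arrow.Table(path)

# Read an Arrow stream file back into a Spark DataFrame (ArrowConverters.toDataFrame under the hood).
jdf = jcall(JDatasetUtils, "fromArrow1", JDataset, (JSparkSession, JString), spark.jspark, path)
```

These calls mirror the jcall signatures used in `src/dataframe.jl` and `src/session.jl` below.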
54 changes: 54 additions & 0 deletions src/dataframe.jl
@@ -2,6 +2,11 @@
# DataFrame #
###############################################################################


import Arrow
import DataFrames
import Tables

Base.show(df::DataFrame) = jcall(df.jdf, "show", Nothing, ())
Base.show(df::DataFrame, n::Integer) = jcall(df.jdf, "show", Nothing, (jint,), n)
function Base.show(io::IO, df::DataFrame)
@@ -19,6 +24,12 @@ function Base.getindex(df::DataFrame, name::String)
    return Column(jcol)
end

function Base.propertynames(df::DataFrame, private::Bool=false)
    column_names = Symbol.(columns(df))
    properties = [ :jdf, :printSchema, :show, :count, :first, :head, :collect, :collect_tuples, :collect_df, :collect_arrow, :take, :describe, :alias, :select, :withColumn, :filter, :where, :groupBy, :min, :max, :sum, :mean, :minimum, :maximum, :avg, :createOrReplaceTempView, :isStreaming, :writeStream, :write ]
    return vcat(column_names, properties)
end

function Base.getproperty(df::DataFrame, prop::Symbol)
    if hasfield(DataFrame, prop)
        return getfield(df, prop)
@@ -58,6 +69,45 @@ function Base.collect(df::DataFrame, col::Union{<:AbstractString, <:Integer})
    return [row[col] for row in rows]
end

"""
Returns an array of named tuples that contains all rows in this DataFrame.
```
julia> spark.sql("select 1 as a, 'x' as b, array(1, 2, 3) as c").collect_tuples()
1-element Vector{NamedTuple{(:a, :b, :c), Tuple{Int32, String, Vector{Int32}}}}:
(a = 1, b = "x", c = [1, 2, 3])
```
"""
function collect_tuples(ds::DataFrame)
    Tables.rowtable(collect_arrow(ds))
end

"""
Returns a DataFrame from DataFrames.jl that contains all rows in this Spark DataFrame.
```
julia> spark.sql("select 1 as a, 'x' as b, array(1, 2, 3) as c").collect_df()

1×3 DataFrame
Row │ a b c
│ Int32 String Array…
─────┼───────────────────────────────
1 │ 1 x Int32[1, 2, 3]

```
"""
function collect_df(ds::DataFrame)
DataFrames.DataFrame(collect_arrow(ds))
end


"""Returns an Arrow.Table that contains all rows in this Dataset.
This function will be slightly faster than collect_to_dataframe."""
function collect_arrow(ds::DataFrame)
mktemp() do path,io
jcall(JDatasetUtils, "collectToArrow2", Nothing, (JDataset, JString, jboolean), ds.jdf, path, false)
Arrow.Table(path)
end
end

function take(df::DataFrame, n::Integer)
    return convert(Vector{Row}, jcall(df.jdf, "take", JObject, (jint,), n))
end
@@ -147,6 +197,10 @@ end
###############################################################################

@chainable GroupedData
function Base.propertynames(gdf::GroupedData, private::Bool=false)
    [:show, :agg, :min, :max, :sum, :mean, :minimum, :maximum, :avg, :count]
end

function Base.show(io::IO, gdf::GroupedData)
    repr = jcall(gdf.jgdf, "toString", JString, ())
    repr = replace(repr, "RelationalGroupedDataset" => "GroupedData")
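The three collectors above are thin layers over the same Arrow stream and differ only in the Julia-side representation. A short usage sketch, assuming a running `spark` session:

```
df = spark.sql("select 1 as a, 'x' as b, array(1, 2, 3) as c")

tbl = df.collect_arrow()    # Arrow.Table; columns accessible as tbl.a, tbl.b, tbl.c
nts = df.collect_tuples()   # Vector of NamedTuples (Tables.rowtable over the Arrow.Table)
jdf = df.collect_df()       # DataFrames.DataFrame built from the Arrow.Table
```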
3 changes: 2 additions & 1 deletion src/defs.jl
@@ -16,6 +16,7 @@ const JDataStreamReader = @jimport org.apache.spark.sql.streaming.DataStreamReader
const JDataStreamWriter = @jimport org.apache.spark.sql.streaming.DataStreamWriter
const JStreamingQuery = @jimport org.apache.spark.sql.streaming.StreamingQuery
const JDataset = @jimport org.apache.spark.sql.Dataset
const JDatasetUtils = @jimport org.apache.spark.sql.julia.DatasetUtils
const JRelationalGroupedDataset = @jimport org.apache.spark.sql.RelationalGroupedDataset

# const JRowFactory = @jimport org.apache.spark.sql.RowFactory
@@ -149,4 +150,4 @@
"A handle to a query that is executing continuously in the background as new data arrives"
struct StreamingQuery
    jquery::JStreamingQuery
end
end
31 changes: 29 additions & 2 deletions src/session.jl
@@ -2,6 +2,11 @@
# SparkSession.Builder #
###############################################################################

import Tables
import TableTraits
import DataFrames
import Arrow

@chainable SparkSessionBuilder
Base.show(io::IO, ::SparkSessionBuilder) = print(io, "SparkSessionBuilder()")

@@ -38,7 +43,7 @@ end
###############################################################################
# SparkSession #
###############################################################################

Base.propertynames(::SparkSession, private::Bool=false) = [:version, :stop, :conf, :createDataFrame, :createDataFrameFromTable, :sql]
@chainable SparkSession
Base.show(io::IO, ::SparkSession) = print(io, "SparkSession()")

@@ -101,6 +106,28 @@ function createDataFrame(spark::SparkSession, rows::Vector{Row})
    return spark.createDataFrame(rows, st)
end

"Creates Spark DataFrame from the Julia table using Arrow.jl for data transfer."
function createDataFrame(spark::SparkSession, data::Tables.AbstractColumns)
createDataFrameFromTable(spark, data)
end
"Creates Spark DataFrame from the Julia table using Arrow.jl for data transfer."
function createDataFrame(spark::SparkSession, data::DataFrames.AbstractDataFrame)
createDataFrameFromTable(spark, data)
end

"Creates Spark DataFrame from any Tables.jl compatible table. Uses Arrow.jl for data transfer. When localRelation=true a LocalRelation is creates and Spark should be able to perform filter pushdown on JOINs with this DataFrame"
function createDataFrameFromTable(spark::SparkSession, table, localRelation=false)
mktemp() do path,io
Arrow.write(path, table; file=false)
fn = if localRelation
"fromArrow2"
else
"fromArrow1"
end
jdf = jcall(JDatasetUtils, fn, JDataset, (JSparkSession,JString), spark.jspark, path)
DataFrame(jdf)
end
end

function sql(spark::SparkSession, query::String)
    jdf = jcall(spark.jspark, "sql", JDataset, (JString,), query)
@@ -145,4 +172,4 @@ for JT in (JString, jlong, jboolean)
    @eval function set(cnf::RuntimeConfig, key::String, value::$T)
        jcall(cnf.jconf, "set", Nothing, (JString, $JT), key, value)
    end
end
end
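The reverse direction, building a Spark DataFrame from a Julia table — a minimal sketch, assuming a running `spark` session and the dot-call syntax used elsewhere in the package; the `localRelation` argument of `createDataFrameFromTable` selects `fromArrow2`, which builds a LocalRelation and should allow filter pushdown on joins:

```
import DataFrames

tbl = DataFrames.DataFrame(a = [1, 2, 3], b = ["x", "y", "z"])

sdf = spark.createDataFrame(tbl)                         # Arrow stream -> fromArrow1
sdf_local = spark.createDataFrameFromTable(tbl, true)    # fromArrow2: LocalRelation
sdf.show()
```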
3 changes: 3 additions & 0 deletions test/data/nestedStructures.json
@@ -0,0 +1,3 @@
{ "name": "Peter", "pets": [{ "name": "Albert", "species": "dog" }] }
{ "name": "Thomas", "pets": [] }
{ "name": "Peter", "pets": [{ "name": null, "species": "mouse" }, { "name": null, "species": "mouse" }, { "name": null, "species": "mouse" }] }
3 changes: 2 additions & 1 deletion test/data/people.json
@@ -1 +1,2 @@
[{"name": "Peter", "age": 32}, {"name": "Belle", "age": 27}]
{"name": "Peter", "age": 32}
{"name": "Belle", "age": 27}
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -19,6 +19,7 @@ include("test_chainable.jl")
include("test_convert.jl")
include("test_compiler.jl")
include("test_sql.jl")
include("test_arrow.jl")

spark.stop()

24 changes: 24 additions & 0 deletions test/test_arrow.jl
@@ -0,0 +1,24 @@
using Spark
using Spark.Compiler
import DataFrames

data_dir = joinpath(@__DIR__, "data")

@testset "collect_df" begin
    numbers = spark.sql("SELECT * FROM range(-3, 0)").collect_df()
    @test numbers == DataFrames.DataFrame(id = [-3, -2, -1])

    people = spark.read.json(joinpath(data_dir, "people.json")).collect_df()

    @test isequal(people.name, ["Peter", "Belle"])

    nested = spark.read.json(joinpath(data_dir, "nestedStructures.json")).collect_df()
    pet_names = [ p.name for p in nested.pets |> Iterators.flatten |> collect ]
    @test isequal(pet_names, [ "Albert", missing, missing, missing ])
end

@testset "createDataFrame" begin
    nested = spark.read.json(joinpath(data_dir, "nestedStructures.json")).collect_df()
    spark_df = spark.createDataFrame(nested)
    @test isequal(spark_df.collect_df(), nested)
end