Add Scala Spark example doc page (com-lihaoyi#4513)
This PR implements the example docs for Spark.

Closes [Issue 4453](com-lihaoyi#4453)

Checklist:
- [x] **example/scalalib/spark**
     - [x]  1-hello-spark
     - [x]  2-hello-pyspark
     - [x]  3-semi-realistic (+ spark-submit)
monyedavid authored Feb 10, 2025
1 parent 3a47816 commit a40b5b0
Showing 14 changed files with 426 additions and 0 deletions.
1 change: 1 addition & 0 deletions example/package.mill
@@ -63,6 +63,7 @@ object `package` extends RootModule with Module {
    object publishing extends Cross[ExampleCrossModule](build.listIn(millSourcePath / "publishing"))
    object web extends Cross[ExampleCrossModule](build.listIn(millSourcePath / "web"))
    object native extends Cross[ExampleCrossModule](build.listIn(millSourcePath / "native"))
    object spark extends Cross[ExampleCrossModule](build.listIn(millSourcePath / "spark"))
  }
  object javascriptlib extends Module {
    object basic extends Cross[ExampleCrossModule](build.listIn(millSourcePath / "basic"))
38 changes: 38 additions & 0 deletions example/scalalib/spark/1-hello-spark/build.mill
@@ -0,0 +1,38 @@
package build
import mill._, scalalib._

object foo extends ScalaModule {
  def scalaVersion = "2.12.15"
  def ivyDeps = Agg(
    ivy"org.apache.spark::spark-core:3.5.4",
    ivy"org.apache.spark::spark-sql:3.5.4"
  )

  // Spark needs access to JDK internals (sun.nio.ch) when forked on newer JVMs
  def forkArgs = Seq("--add-opens", "java.base/sun.nio.ch=ALL-UNNAMED")

  object test extends ScalaTests {
    def ivyDeps = Agg(ivy"com.lihaoyi::utest:0.8.5")
    def testFramework = "utest.runner.Framework"

    def forkArgs = Seq("--add-opens", "java.base/sun.nio.ch=ALL-UNNAMED")
  }

}
// This example demonstrates running Spark with Mill.

/** Usage

> ./mill foo.run
...
+-------------+
| message|
+-------------+
|Hello, World!|
+-------------+
...

> ./mill foo.test
...
+ foo.FooTests.helloWorld should create a DataFrame with one row containing 'Hello, World!'...
...
*/
23 changes: 23 additions & 0 deletions example/scalalib/spark/1-hello-spark/foo/src/foo/Foo.scala
@@ -0,0 +1,23 @@
package foo

import org.apache.spark.sql.{DataFrame, SparkSession}

object Foo {

  def helloWorld(spark: SparkSession): DataFrame = {
    val data = Seq("Hello, World!")
    val df = spark.createDataFrame(data.map(Tuple1(_))).toDF("message")
    df
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("HelloWorld")
      .master("local[*]")
      .getOrCreate()

    helloWorld(spark).show()

    spark.stop()
  }
}
22 changes: 22 additions & 0 deletions example/scalalib/spark/1-hello-spark/foo/test/src/FooTests.scala
@@ -0,0 +1,22 @@
package foo

import org.apache.spark.sql.SparkSession
import utest._

object FooTests extends TestSuite {
  def tests = Tests {
    test("helloWorld should create a DataFrame with one row containing 'Hello, World!'") {
      val spark = SparkSession.builder()
        .appName("HelloWorldTest")
        .master("local[*]")
        .getOrCreate()

      val df = Foo.helloWorld(spark)
      val messages = df.collect().map(_.getString(0)).toList
      assert(messages == List("Hello, World!"))

      // Stop the SparkSession
      spark.stop()
    }
  }
}
32 changes: 32 additions & 0 deletions example/scalalib/spark/2-hello-pyspark/build.mill
@@ -0,0 +1,32 @@
package build
import mill._, pythonlib._

object foo extends PythonModule {

  def mainScript = Task.Source { millSourcePath / "src" / "foo.py" }
  def pythonDeps = Seq("pyspark==3.5.4")

  object test extends PythonTests with TestModule.Unittest

}

/** Usage

> ./mill foo.run
...
+-------------+
| message|
+-------------+
|Hello, World!|
+-------------+
...

> ./mill foo.test
...
test_hello_world...
...
Ran 1 test...
...
OK
...
*/
19 changes: 19 additions & 0 deletions example/scalalib/spark/2-hello-pyspark/foo/src/foo.py
@@ -0,0 +1,19 @@
from pyspark.sql import SparkSession, DataFrame

def hello_world(spark: SparkSession) -> DataFrame:
    data = [("Hello, World!",)]
    df = spark.createDataFrame(data, ["message"])
    return df

def main():
    spark = SparkSession.builder \
        .appName("HelloWorld") \
        .master("local[*]") \
        .getOrCreate()

    hello_world(spark).show()

    spark.stop()

if __name__ == "__main__":
    main()
23 changes: 23 additions & 0 deletions example/scalalib/spark/2-hello-pyspark/foo/test/src/test.py
@@ -0,0 +1,23 @@
import unittest
from pyspark.sql import SparkSession
from foo import hello_world

class HelloWorldTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder \
            .appName("HelloWorldTest") \
            .master("local[*]") \
            .getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

    def test_hello_world(self):
        df = hello_world(self.spark)
        messages = [row['message'] for row in df.collect()]
        self.assertEqual(messages, ["Hello, World!"])

if __name__ == "__main__":
    unittest.main()
62 changes: 62 additions & 0 deletions example/scalalib/spark/3-semi-realistic/build.mill
@@ -0,0 +1,62 @@
package build
import mill._, scalalib._

object `package` extends RootModule with ScalaModule {
  def scalaVersion = "2.12.15"
  def ivyDeps = Agg(
    ivy"org.apache.spark::spark-core:3.5.4",
    ivy"org.apache.spark::spark-sql:3.5.4"
  )

  // Spark needs access to JDK internals (sun.nio.ch) when forked on newer JVMs
  def forkArgs = Seq("--add-opens", "java.base/sun.nio.ch=ALL-UNNAMED")

  // Keep the assembly a plain jar (no self-executing launcher prefix) so spark-submit can run it
  def prependShellScript = ""

  object test extends ScalaTests {
    def ivyDeps = Agg(ivy"com.lihaoyi::utest:0.8.5")
    def testFramework = "utest.runner.Framework"

    def forkArgs = Seq("--add-opens", "java.base/sun.nio.ch=ALL-UNNAMED")
  }

}

// This example demonstrates a semi-realistic Spark job that calculates summary statistics
// from a transactions.csv passed in as an argument, falling back to the bundled resource if none is given.

/** Usage

> ./mill run
...
Summary Statistics by Category:
+-----------+------------+--------------+-----------------+
| category|total_amount|average_amount|transaction_count|
+-----------+------------+--------------+-----------------+
| Food| 70.5| 23.5| 3|
|Electronics| 375.0| 187.5| 2|
| Clothing| 120.5| 60.25| 2|
+-----------+------------+--------------+-----------------+
...

> ./mill test
...
+ foo.FooTests.computeSummary should compute correct summary statistics...
...

> chmod +x spark-submit.sh

> ./mill show assembly # prepare for spark-submit
".../out/assembly.dest/out.jar"

> ./spark-submit.sh out/assembly.dest/out.jar foo.Foo resources/transactions.csv
...
Summary Statistics by Category:
+-----------+------------+--------------+-----------------+
| category|total_amount|average_amount|transaction_count|
+-----------+------------+--------------+-----------------+
| Food| 70.5| 23.5| 3|
|Electronics| 375.0| 187.5| 2|
| Clothing| 120.5| 60.25| 2|
+-----------+------------+--------------+-----------------+
...
*/
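Because the main method only falls back to the bundled resource when no argument is given, the example can also be pointed at another dataset by forwarding a path through Mill's run command (the path below is hypothetical):

./mill run path/to/other-transactions.csv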
8 changes: 8 additions & 0 deletions example/scalalib/spark/3-semi-realistic/resources/transactions.csv
@@ -0,0 +1,8 @@
id,category,amount
1,Food,20.5
2,Electronics,250.0
3,Food,35.0
4,Clothing,45.5
5,Food,15.0
6,Electronics,125.0
7,Clothing,75.0
67 changes: 67 additions & 0 deletions example/scalalib/spark/3-semi-realistic/spark-submit.sh
@@ -0,0 +1,67 @@
#!/bin/bash

# Check if at least 2 arguments are provided
if [ "$#" -lt 2 ]; then
echo "Usage: $0 path/to/your-module-assembly.jar fully.qualified.MainClass [path/to/resource.csv]"
exit 1
fi

# The first argument is the JAR path, the second is the main class
JAR_PATH="$1"
MAIN_CLASS="$2"
MASTER="local[*]"

# Shift out the first two arguments so that any remaining ones (like a resource argument) are forwarded
shift 2

# Function to install Apache Spark via Homebrew (macOS)
install_spark_brew() {
echo "Installing Apache Spark via Homebrew..."
brew update && brew install apache-spark
}

# Function to download and extract Apache Spark manually
install_spark_manual() {
  # Use a Spark version matching the example's spark-core/spark-sql dependency
  SPARK_VERSION="3.5.4"
  HADOOP_VERSION="3"
  SPARK_PACKAGE="spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}"
  DOWNLOAD_URL="https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/${SPARK_PACKAGE}.tgz"
  INSTALL_DIR="$HOME/spark"

  mkdir -p "$INSTALL_DIR"
  echo "Downloading Apache Spark from $DOWNLOAD_URL..."
  # Use -fL to fail on HTTP errors and follow redirects.
  curl -fL "$DOWNLOAD_URL" -o "$INSTALL_DIR/${SPARK_PACKAGE}.tgz" || { echo "Download failed."; exit 1; }

  echo "Extracting Apache Spark..."
  tar -xzf "$INSTALL_DIR/${SPARK_PACKAGE}.tgz" -C "$INSTALL_DIR" || { echo "Extraction failed."; exit 1; }

  # Set SPARK_HOME and update PATH
  export SPARK_HOME="$INSTALL_DIR/${SPARK_PACKAGE}"
  export PATH="$SPARK_HOME/bin:$PATH"
}

# Check if spark-submit is installed
if ! command -v spark-submit &> /dev/null; then
echo "spark-submit not found. Installing Apache Spark..."
if command -v brew &> /dev/null; then
install_spark_brew
else
install_spark_manual
fi
fi

# Verify installation
if ! command -v spark-submit &> /dev/null; then
echo "spark-submit is still not available. Exiting."
exit 1
fi

echo "spark-submit is installed. Running the Spark application..."

# Run spark-submit, forwarding any additional arguments (e.g., resource path) to the application
spark-submit \
  --class "$MAIN_CLASS" \
  --master "$MASTER" \
  "$JAR_PATH" \
  "$@"
47 changes: 47 additions & 0 deletions example/scalalib/spark/3-semi-realistic/src/foo/Foo.scala
@@ -0,0 +1,47 @@
package foo

import org.apache.spark.sql.{SparkSession, Dataset, DataFrame}
import org.apache.spark.sql.functions._

object Foo {

  case class Transaction(id: Int, category: String, amount: Double)

  def computeSummary(transactions: Dataset[Transaction]): DataFrame = {
    transactions.groupBy("category")
      .agg(
        sum("amount").alias("total_amount"),
        avg("amount").alias("average_amount"),
        count("amount").alias("transaction_count")
      )
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SparkExample")
      .master("local[*]")
      .getOrCreate()

    // Check for a file path provided as a command-line argument first;
    // otherwise, use resources.
    val resourcePath: String = args.headOption
      .orElse(Option(getClass.getResource("/transactions.csv")).map(_.getPath))
      .getOrElse(throw new RuntimeException(
        "transactions.csv not provided as argument and not found in resources"
      ))

    import spark.implicits._

    val df = spark.read
      .option("header", "true")
      .option("inferSchema", "true")
      .csv(resourcePath)

    val transactionsDS: Dataset[Transaction] = df.as[Transaction]
    val summaryDF = computeSummary(transactionsDS)

    println("Summary Statistics by Category:")
    summaryDF.show()
    spark.stop()
  }
}