This short article shows how to use Python user-defined functions (UDFs) in PySpark applications. To use a UDF, we need to complete three basic steps:
- Create a UDF (user-defined function) in Python
- Register the UDF
- Use the UDF in Spark SQL
User-defined functions (UDFs) are user-programmable functions that operate on one row at a time. In Spark SQL and the DataFrame API, UDFs extend Spark's built-in functions with custom logic, and a UDF can be defined once and reused across several DataFrames.
Consider a function that triples its input:
# n : integer
def tripled(n):
    return 3 * n
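Before registering, the function can be checked as ordinary Python. One caveat worth testing: Spark passes SQL NULL values to a Python UDF as None, and 3 * None raises a TypeError, which would fail the Spark task. The sketch below adds a null-safe variant; tripled_safe is a hypothetical name, not part of the original example, and no Spark installation is needed to run it.

```python
from typing import Optional

def tripled(n):
    return 3 * n

def tripled_safe(n: Optional[int]) -> Optional[int]:
    """Hypothetical null-safe variant: Spark hands SQL NULLs to a Python UDF as None."""
    if n is None:
        return None  # propagate NULL instead of raising TypeError
    return 3 * n

# Plain-Python checks before anything is registered with Spark.
print(tripled(12000))      # 36000
print(tripled_safe(None))  # None
try:
    tripled(None)
except TypeError as exc:
    print("tripled(None) fails:", exc)
```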
To register a UDF, we can use SparkSession.udf.register(). The register() function takes three parameters:
- 1st: the name under which the UDF will be called in SQL
- 2nd: the Python function to register
- 3rd: the return data type of the Python function (if this parameter is omitted, the return type defaults to StringType)
from pyspark.sql.types import IntegerType
# "tripled_udf" : desired name to use in SQL
# tripled : the Python function defined above
# IntegerType() : the return type of the UDF
spark.udf.register("tripled_udf", tripled, IntegerType())
Now, let's create a sample DataFrame and then apply the UDF we registered:
>>> data = [('alex', 20, 12000), ('jane', 30, 45000),
...         ('rafa', 40, 56000), ('ted', 30, 145000),
...         ('xo2', 10, 1332000), ('mary', 44, 555000)]
>>> column_names = ['name', 'age', 'salary']
>>> df = spark.createDataFrame(data, column_names)
>>> df
DataFrame[name: string, age: bigint, salary: bigint]
>>> df.printSchema()
root
 |-- name: string (nullable = true)
 |-- age: long (nullable = true)
 |-- salary: long (nullable = true)
>>> df.show()
+----+---+-------+
|name|age| salary|
+----+---+-------+
|alex| 20|  12000|
|jane| 30|  45000|
|rafa| 40|  56000|
| ted| 30| 145000|
| xo2| 10|1332000|
|mary| 44| 555000|
+----+---+-------+
>>> df.count()
6
>>> df.createOrReplaceTempView("people")
>>> df2 = spark.sql("select * from people where salary > 67000")
>>> df2.show()
+----+---+-------+
|name|age| salary|
+----+---+-------+
| ted| 30| 145000|
| xo2| 10|1332000|
|mary| 44| 555000|
+----+---+-------+
>>> df2 = spark.sql("select name, age, salary, tripled_udf(salary) as tripled_salary from people")
>>> df2.show()
+----+---+-------+--------------+
|name|age| salary|tripled_salary|
+----+---+-------+--------------+
|alex| 20|  12000|         36000|
|jane| 30|  45000|        135000|
|rafa| 40|  56000|        168000|
| ted| 30| 145000|        435000|
| xo2| 10|1332000|       3996000|
|mary| 44| 555000|       1665000|
+----+---+-------+--------------+