UDFs

User Defined Functions

There are several ways to define and use a UDF in PySpark. Below is the quickest way that worked for me, using a decorator.

# sample data frame
df = spark.createDataFrame([(1, "John Doe", 21, 2019, 2), (2, "Jane Doe", 24, 2021, 3)], ("id", "name", "age", "year","month"))
df.show()

import pyspark.sql.functions as f
from pyspark.sql.types import IntegerType


@f.udf(returnType=IntegerType())  # `f.udf` because functions are imported as f here
def month_length(y: int, m: int) -> int:
    import calendar
    value = calendar.monthrange(y, m)[1]
    return value
    
df = df.withColumn("days_in_month", month_length("year", "month"))
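Calling df.show() again at this point should give something like the following (February 2019 has 28 days, March 2021 has 31), exact formatting aside:

>>> df.show()
+---+--------+---+----+-----+-------------+
| id|    name|age|year|month|days_in_month|
+---+--------+---+----+-----+-------------+
|  1|John Doe| 21|2019|    2|           28|
|  2|Jane Doe| 24|2021|    3|           31|
+---+--------+---+----+-----+-------------+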

☞ Notes

  • A decorator is a function that takes another function as an argument and returns a modified version of it; here, f.udf wraps our plain Python function into a Spark UDF (see the sketch after these notes). Search online for Python decorators if you want more details.

  • See the Python documentation for details about the calendar.monthrange() function.

  • When calling your function in .withColumn(), pass the column names in the same order as the parameters defined in the function. Passing them as keyword arguments (by parameter name) failed for me.

  • The example is adapted from the official Spark 3.1 UDF documentation page.

  • Remember that your function's input parameters are columns in the data frame. Make sure the column types are accounted for when you write your function.

  • Unlike Scala, Python doesn't require type hints on the input parameters, and they don't affect the function's behavior; I like to add them just to keep track of what's what.
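For reference, here is a minimal sketch of the equivalent non-decorator form of the same example; month_length_plain and month_length_udf are just illustrative names:

import calendar

import pyspark.sql.functions as f
from pyspark.sql.types import IntegerType

# plain Python function, no decorator
def month_length_plain(y: int, m: int) -> int:
    return calendar.monthrange(y, m)[1]

# wrapping it explicitly is what the @f.udf decorator does for you
month_length_udf = f.udf(month_length_plain, returnType=IntegerType())

# pass column names positionally, in the order the parameters are defined
df = df.withColumn("days_in_month", month_length_udf("year", "month"))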

Using a UDF to Perform Element-Wise Operations on ArrayType Columns

Say you want to compare two ArrayType() columns in your data frame, test whether each element satisfies a certain condition, or do anything else to each element of an ArrayType column. How do you do that?

Create a UDF that takes in the specific column type you want to operate on and returns the desired type. For example, if you want to test whether each element in an ArrayType column meets a certain condition, you might return an ArrayType column of BooleanType, i.e. an array of true/false values. Everything else in defining the UDF is the same as before. Here is an example.

Creating the toy data frame

from pyspark.sql import Row

# create a tuple of Row objects
arrayStructureData = (
    Row("James,,Smith", [2, 3, 20], ["Spark", "Java"], "OH", "CA", 123, 456.78, 0.1),
    Row("Michael,Rose,", [4, 6, 3], ["Spark", "Java"], "NY", "NJ", 75, 234.01, 0.2),
    Row("Robert,,Williams", [10, 15, 6], ["Spark", "Python"], "UT", "NV", 82, 987.02, 0.7),
    Row("John,,Doe", [20, 25, 62], ["C++", "Python"], "TN", "TN", 98, 332.30, 0.9),
    Row("Jane,,Doe", [50, 55, 65], ["Spark", "C++"], "TX", "MN", 61, 980.23, 0.8),
    Row("Jack,,Smith", [11, 34, 98], ["JavaScript", "Go"], "CA", "MI", 110, 937.94, 0.5),
    Row("Jillan,,Bernard", [2, 1, 9], ["R", "Python"], "CT", "NY", 132, 128.95, 0.6),
    Row("Phillip,,Kraft", [1, 13, 0], ["Python", "Stata"], "RI", "FL", 95, 563.63, 0.1),
    Row("Karl,,Lund", [74, 92, 14], ["Go", "Python"], "CA", "WA", 55, 614.84, 0.4),
)

# define the column names
COLUMNS = ["name", "arr1", "lang", "state0", "state1", "num1", "num2", "num3"]

# make them into a data frame
dummy_df = spark.createDataFrame(arrayStructureData, COLUMNS)

Now mind the types of the columns, specifically the element type of each ArrayType column:

>>> dummy_df.printSchema()
root
 |-- name: string (nullable = true)
 |-- arr1: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- lang: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- state0: string (nullable = true)
 |-- state1: string (nullable = true)
 |-- num1: long (nullable = true)
 |-- num2: double (nullable = true)
 |-- num3: double (nullable = true)

Let's work with column arr1 and test whether each element is greater than 45 or not:

from pyspark.sql import functions as f
from pyspark.sql.types import *

@f.udf(returnType=ArrayType(BooleanType()))  # `f.udf` because functions are imported as f here
def test_elements(num_arr):
    return [item > 45 for item in num_arr]
    
dummy_df\
.withColumn("greater_than_45", test_elements("arr1") )\
.select("arr1", "greater_than_45")\
.show(10,truncate=False)

# if the results look as intended, overwrite the data frame with the new column
dummy_df = dummy_df.withColumn("greater_than_45", test_elements("arr1"))

You can do pretty much anything to array columns that you would usually do to Python lists. The beautiful part is that you don't need to convert anything to pandas or another native Python object. All you have to do is wrap your logic in a function, turn it into a UDF with the decorator, and then pass the column names as strings when you call it. That is, you don't need f.col("colname"); in fact, that gave me an error.

☞ Notes

  • How come we didn't use f.col("arr1") for the UDF argument(s)? Since we defined our function as a Spark user-defined function, any column name you pass as a string is resolved to a Column for you. In fact, when I did use f.col("arr1") I got a vague error about serialization.

  • To design a UDF, think about the element type in each row of the column you're passing to your UDF, because inside the UDF it is read as a plain Python object; e.g. an ArrayType column arrives as a Python list.

  • Be very mindful of your types. e.g. if you define returnType=FloatType() but your function returns a NumPy float, you'll get a confusing error, in typical Spark fashion. You must convert it to a Python float to resolve the issue, i.e. in the final return statement do return float(result), as sketched after these notes.

  • For a quick-and-dirty job, you can skip defining a returnType altogether when passing the function to udf (it then defaults to StringType). However, defining the return type makes sure the output column is the type you need, in case you have further transformations or computations to do on it.

  • Just for kicks, the error I got for the NumPy float while using returnType=FloatType() was: Job aborted due to stage failure ... net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)
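As a minimal sketch of the NumPy point above (arr_mean is a hypothetical helper, and this assumes NumPy is available on the workers):

import numpy as np
from pyspark.sql import functions as f
from pyspark.sql.types import FloatType

@f.udf(returnType=FloatType())
def arr_mean(num_arr):
    # np.mean returns a numpy.float64; cast it to a plain Python float
    # so it matches FloatType() and avoids the pickle error quoted above
    return float(np.mean(num_arr))

dummy_df.withColumn("arr1_mean", arr_mean("arr1")).select("arr1", "arr1_mean").show()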

Optimizing PySpark UDFs

Unlike Scala UDFs, Python UDFs are slow: each row has to be serialized, shipped to a Python worker, and deserialized back, which slows your transformation down by a lot. That applies to legacy Python UDFs like the ones discussed above. Pandas UDFs (introduced in Spark 2.3 and redesigned around Python type hints in Spark 3.0) instead work on batches of rows at once, i.e. vectorized, rather than one row at a time like legacy UDFs. Here's the best blog post I've found on writing good Pandas UDFs: https://databricks.com/blog/2020/05/20/new-pandas-udfs-and-python-type-hints-in-the-upcoming-release-of-apache-spark-3-0.html

To use Pandas UDFs, replace the decorator @udf with @pandas_udf (or @f.pandas_udf, since we earlier imported pyspark.sql.functions as f), and adjust the function so it accepts and returns pandas Series, i.e. whole batches of rows instead of single values; a sketch follows below. Keep in mind to run a check with .display(), .show(), or .count() on the dataframe after applying the UDF to make sure it executed successfully. Those commands force Spark to execute the plan, revealing any hidden errors. Recall that Spark is lazily evaluated: it doesn't actually run anything until an action, like one of those three, is required.
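As a sketch (assuming Spark 3.0+ with pandas and PyArrow available), here is what the first month_length example might look like as a Pandas UDF; the function now receives and returns pandas Series:

import calendar

import pandas as pd
import pyspark.sql.functions as f
from pyspark.sql.types import IntegerType

@f.pandas_udf(IntegerType())
def month_length_pd(year: pd.Series, month: pd.Series) -> pd.Series:
    # each call gets a batch of rows as pandas Series; return a Series of the same length
    return pd.Series([calendar.monthrange(y, m)[1] for y, m in zip(year, month)])

df = df.withColumn("days_in_month", month_length_pd("year", "month"))

Since this body still loops per row, the win here comes mostly from Arrow-based batched (de)serialization; Pandas UDFs shine most when the body itself is vectorized, e.g. plain arithmetic on the Series.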

☞ If you hit a "ValueError: truth value ambiguous..." type of error, the kind we sometimes see with pandas dataframes when filtering or applying a function, just replace f.pandas_udf with f.udf to fall back to the legacy PySpark UDF.

💡 Whenever you can avoid PySpark UDFs altogether, do it. Many times, breaking the work down into operations on columns, by creating new columns or using Window functions, does what you're trying to do more efficiently (see the sketch below). You can always drop the helper columns you created along the way to keep your dataframes clean and the data processing to the minimum necessary.
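For instance, the greater_than_45 example above doesn't actually need a UDF; a sketch using Spark's built-in higher-order transform function (available in the SQL engine since Spark 2.4, and in the Python API as f.transform since 3.1):

from pyspark.sql import functions as f

# element-wise test on an array column with a built-in higher-order function,
# no Python UDF and no serialization round trip involved
dummy_df = dummy_df.withColumn("greater_than_45", f.expr("transform(arr1, x -> x > 45)"))

# on Spark 3.1+ the same thing works directly in the Python API:
# dummy_df = dummy_df.withColumn("greater_than_45", f.transform("arr1", lambda x: x > 45))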

☞ Further resources

There are a lot of great articles on the Databricks engineering blog, https://databricks.com/blog/category/engineering and they even have free self-paced courses. However, not all Databricks features are available in open-source Apache Spark, e.g. Z-ordering on Delta files, or dbutils, or some of the optimization settings Databricks enables on its clusters by default, like Adaptive Query Execution (setting spark.sql.adaptive.enabled to true, as shown below).
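For example, on open-source Spark 3.x you can switch on Adaptive Query Execution yourself:

# enable Adaptive Query Execution on an existing SparkSession
spark.conf.set("spark.sql.adaptive.enabled", "true")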
