Full Worked Random Forest Classifier Example

Code Snippets for RandomForestClassifier - PySpark

The general plan is:

  1. Extract all string columns, to encode them afterwards with StringIndexer.

  2. Extract all numerical columns, to impute nulls if the model complains about them while fitting.

  3. Create a Pipeline for all the steps you want to do, e.g. StringIndexer, Imputer, OneHotEncoder, StandardScaler (though standardizing isn't needed for Random Forest) and VectorAssembler (to create the "features" column). Don't forget to leave out the target variable. It has to be numeric, not strings: 0/1 for binary classification, or integers (0, 1, 2, ...) for multiclass.

  4. Fit the Pipeline and transform the dataframe.

  5. Split the dataframe into training and testing datasets using randomSplit. Don't forget to set the seed argument in randomSplit, so that you get the same results if you run the same experiment again later.

  6. Instantiate the RandomForestClassifier model. At this point, don't pass the "features" column name as the featuresCol argument; we'll set it later. For some reason, it errored out when I included it while instantiating the model. Don't forget to set the seed here too, for reproducibility.

  7. Call setFeaturesCol on your instantiated RandomForestClassifier model.

  8. Fit the model to the training set.

  9. Save the fitted model for later use if needed, and load it again afterwards.

  10. Predict on the holdout set. Save the dataframe with those predictions, just in case.

  11. Instantiate the BinaryClassificationEvaluator or the MulticlassClassificationEvaluator depending on what you have.

  12. Call the evaluator on your predictions dataframe.

  13. Create the Confusion Matrix, with a workaround.

  14. Create the Feature Importance plot, with a workaround.

The Imports

import pyspark.sql.functions as f
from pyspark.sql.types import *

from pyspark.ml.pipeline import Pipeline
from pyspark.ml.feature import OneHotEncoder, StandardScaler, VectorAssembler, StringIndexer, Imputer

from pyspark.ml.classification import RandomForestClassifier, RandomForestClassificationModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

Extract String and Numerical Columns

Besides printSchema(), you can see the types of all the columns in a PySpark dataframe with df.dtypes, which returns a list of (column name, type) tuples. We will use this next.

You can compute whatever differences and calculations you want from your TimestampType and DateType columns, after which you might want to drop them from your dataframe.

Then make sure the ID column(s) and the target variable are excluded from the predictor columns.

Now we're ready to put the StringType columns in one list, and all numerical predictors in another list. Make sure you're selecting only the predictors in this step!

☞ Note

You can count how many columns you have of each type with the following snippet:

This will give you a great overview of what you've got. For example, if you have a "decimal" type column, you might want to convert it to "double". In my experience, if you convert the dataframe to Pandas at any point, it will complain that the DecimalType columns are not efficient.

Creating a Pipeline to Do All Feature Engineering

☞ Note

I'm against imputing null values with the median. A good data scientist should look deeper into the data and understand the domain in order to replace missing values, or find another way to handle them. Still, I'm listing the Imputer here for a quick and dirty analysis, if you must.

If you want to learn more about the arguments and what they mean, read the official docs. I find the PySpark 3.1 docs a lot easier to read.

Fit the Pipeline and Transform the Data

Split the DataFrame into Train and Test

Instantiate RandomForestClassifier

Read the docs for the rest of the arguments.

Set Features Column on the Instantiated RF Classifier

Fit to Training Set

Save the Fitted Model

Loading the Fitted Model in a Later Notebook

Predict on Holdout Set

A very important and awesome thing about Spark: the predictions are added as columns to the original dataframe, so you don't lose anything, and you don't need to merge back to know whose prediction this is.

Save the Dataframe with Predictions, Just in Case

Scoring (Evaluating) RFC Model

  • In PySpark, when predicting with a classifier, you'll get 3 new columns, named by the predictionCol, probabilityCol and rawPredictionCol parameters (by default "prediction", "probability" and "rawPrediction").

    • The first one is the predicted label, e.g. the 1/0 of your binary classification.

    • The second one is the equivalent of predict_proba in Scikit-Learn, i.e. the probability of each class. With the defaults, the predicted label is the class with the highest probability. In binary classification, that means class 1 if its probability is > 0.5 and class 0 if it's < 0.5.

    • The last column holds the raw scores (for a random forest, the per-class scores summed across all trees). It's mostly for Spark's internal use; for example, it's the column BinaryClassificationEvaluator reads by default.

    • If you really want to, you can change those default names while instantiating the classifier. In our example, RandomForestClassifier has arguments to change the names of those columns.

  • There are two metrics for BinaryClassificationEvaluator: "areaUnderROC", which is the default, and "areaUnderPR".

Creating the Confusion Matrix, with a Workaround

I chose to go fancy and show the Confusion Matrix as a Pandas dataframe, but the print statement I commented out is enough to show the CM. toArray is a method of Spark's linalg vector and matrix objects (not of a dataframe column), and it converts them to a NumPy array.

You can substitute the words "positive" and "negative" with whatever your classes actually represent, for a cleaner and easier-to-understand presentation.

Source: there are a couple of resources, and this is the one I got the heart of this code from. There is another one I didn't follow, because I doubted that its feature names are in the correct order with respect to the feature importances; my dataframe has several hundred columns.

Plot Top 20 Most Important Features, According to RandomForestClassifier

☞ Notes

To quickly convert all DecimalType columns to DoubleType columns in PySpark:
