Introduction to Machine Learning with Spark and MLlib (DataFrame API)

A pretty hot topic lately is machine learning – an interdisciplinary field closely related to computational statistics that lets computers learn without being explicitly programmed.

It has been found to be of significant use in the field of data analytics – from estimating loan and insurance risk to trying to autonomously steer a car in real-life conditions.

In this post, I would like to introduce the reader to MLlib – a machine learning library that is part of the Spark framework.

One important note: the aim is to introduce the library, not the concepts and theory behind machine learning or statistics in general. A basic understanding of these topics, as well as basic knowledge of Spark, is therefore required.

This post is based on the Apache Spark 2.x API, which offers the DataFrame-based API as an alternative to the older RDD-based one.

The main benefit of the DataFrame approach is that it is easier to use and more user-friendly than the RDD one. The RDD-based API is still present but has been put into maintenance mode: it will no longer be extended and is going to be deprecated once the DataFrame-based API reaches feature parity with it.

Introduction to MLlib

MLlib (short for Machine Learning Library) is Apache Spark’s machine learning library. It brings Spark’s scalability and ease of use to machine learning problems. Under the hood, MLlib uses Breeze for its linear algebra needs.

The library consists of a pretty extensive set of features that I will now briefly present. A more in-depth description of each feature set will be provided in further sections.

Algorithms

  • Regression
    • Linear
    • Generalized Linear
    • Decision Tree
    • Random Forest
    • Gradient-boosted Tree
    • Survival
    • Isotonic
  • Classification
    • Logistic (Binomial and Multinomial)
    • Decision Tree
    • Random Forest
    • Gradient-boosted tree
    • Multilayer Perceptron
    • Linear support vector machine
    • One-vs-All
    • Naive Bayes
  • Clustering
    • K-means
    • Latent Dirichlet allocation
    • Bisecting k-means
    • Gaussian Mixture Model
  • Collaborative Filtering

Featurization

  • Feature extraction
  • Transformation
  • Dimensionality reduction
  • Selection

Pipelines

  • Composing Pipelines
  • Constructing, evaluating and tuning machine learning Pipelines

Persistence

  • Saving algorithms, models and pipelines to persistent data storage for later use
  • Loading algorithms, models and pipelines from persistent data storage

Utilities

  • Linear algebra
  • Statistics
  • Data handling
  • Other

The DataFrame

As mentioned before, the DataFrame-based API is the one employed in Spark 2.x as a replacement for the older RDD API. A DataFrame is a Spark Dataset (a distributed, strongly-typed collection of data; the interface was introduced in Spark 1.6) organized into named columns, which represent the variables. In Scala, DataFrame is simply an alias for Dataset[Row].

The concept is effectively the same as a table in a relational database or a data frame in R/Python, but with a set of implicit optimizations.


What are the main selling points and benefits of using the DataFrame API over the older RDD one? Here are a few:

  • Familiarity – as mentioned beforehand, the concept is analogous to widely known and used approaches to manipulating data, such as tables in relational databases or the data frame construct in e.g. R.
  • Uniform API – the API is consistent across the supported languages, so we don’t waste time accommodating the differences and can focus on what’s important.
  • Spark SQL – it enables us to access and manipulate the data via SQL queries and a SQL-like domain-specific language.
  • Optimizations – a set of optimizations implemented under the hood of Dataset gives us better performance when handling data.
  • Plenty of possible sources – we can construct a Dataset from external databases, existing RDDs, CSV files, JSON and a multitude of other structured data sources.

Creating a DataFrame

As mentioned above, there are multiple possible sources from which we can create a DataFrame. To load a Dataset from an external source, we use the DataFrameReader interface.

In the examples below, we assume a variable named spark exists and holds the SparkSession. The DataFrameReader for the session can be obtained by calling its read method.

We can add input options for the underlying data source by calling the option method on the reader instance. It takes a key and a value as arguments; the options method accepts a whole Map.

There are two approaches to loading the data:

* Format-specific methods like csv, jdbc, etc.

* Specifying the format explicitly with the format method and then calling the generic load method. If no format is set, Parquet is the default one.
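As a rough illustration of the two approaches above, the following sketch reads the same file both ways. It is not taken from the official documentation: the temporary CSV file, its contents and the local[*] master are assumptions made only so that the snippet can run on its own (e.g. in spark-shell using :paste mode).

```scala
import java.nio.file.Files
import org.apache.spark.sql.SparkSession

// Local session for experimenting (assumption); in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .appName("reader-sketch")
  .master("local[*]")
  .getOrCreate()

// Write a tiny, made-up CSV file to a temporary location so the example is self-contained.
val csvPath = Files.createTempFile("people", ".csv")
Files.write(csvPath, "name,age\nAnn,31\nBob,45\n".getBytes("UTF-8"))

// 1. Format-specific method, with options added one by one.
val viaCsv = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv(csvPath.toString)

// 2. Explicit format plus the generic load method, with options passed as a Map.
val viaLoad = spark.read
  .format("csv")
  .options(Map("header" -> "true", "inferSchema" -> "true"))
  .load(csvPath.toString)
```

Both readers should produce DataFrames with the same schema and rows; which style to use is mostly a matter of taste, although the generic form is handy when the format name comes from configuration.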

Here are the most common cases when it comes to creating a DataFrame and the methods used:

Parquet

Parquet is a columnar storage format developed by Apache for projects in the Hadoop/Spark ecosystems.

We load it by calling the load or parquet method with the path to the Parquet file as the argument, e.g.:

spark.read.parquet("some/path/to/file.parquet")


CSV

The well-known comma-separated values file. Spark can automatically infer the schema of a loaded CSV file when the inferSchema option is enabled.

We load it by calling the csv method with the path to the CSV file as the argument, e.g.:

spark.read.csv("some/path/to/file.csv")


JSON

The JavaScript Object Notation format, most widely used by Web applications for asynchronous frontend/backend communication. Spark can automatically infer the schema of a loaded JSON file.

We load it by calling the json method with the path to the JSON file as the argument, e.g.:

spark.read.json("some/path/to/file.json")


Hive

Apache Hive is a data warehouse software package. To interface DataFrames with Hive, we need a SparkSession with Hive support enabled and all the needed dependencies on the classpath for Spark to load them automatically.

We will not cover interfacing with Hive data storage, as this would require understanding what Hive is and how it works. For more information about the topic, please consult the official documentation.


JDBC

We can interface with any database that provides a JDBC driver. For this to be possible, the required JDBC driver for the database has to be included in your classpath.

We can use the load method mentioned before, but the format has to be changed from the default one (Parquet) to jdbc using the format method on the reader. Alternatively, we can use the jdbc method and pass it a Properties instance that holds the connection properties.

We specify the JDBC connection properties via the option method mentioned before. A full list of the options that can be passed, along with their descriptions, is available in the JDBC data source section of the official Spark SQL documentation.

Here is a quick example of creating a DataFrame from a JDBC source. The code could look like this (example from the official documentation):

val jdbcDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.tablename")
  .option("user", "username")
  .option("password", "password")
  .load()

Or using the jdbc method:

import java.util.Properties

val connectionProperties = new Properties()
connectionProperties.put("user", "username")
connectionProperties.put("password", "password")
val jdbcDF2 = spark.read
  .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)


RDD

We can automatically convert an RDD of case class instances into a DataFrame. The names of the case class’s fields become the column names. Nested complex types like Seq or Array are supported.

All we need to do is import spark.implicits._ and call the toDF method on the RDD, i.e.:

val dataFrame = someRDD.toDF()
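To make the conversion more concrete, here is a minimal sketch of the whole round trip. The Wine case class, its sample values and the local[*] master are hypothetical and only serve the illustration.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical record type; its field names become the column names.
// Keep it at the top level (not inside a method) so Spark can derive its encoder.
case class Wine(country: String, points: Double)

// Local session for experimenting (assumption); in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .appName("rdd-to-df-sketch")
  .master("local[*]")
  .getOrCreate()

// Brings the toDF conversion into scope.
import spark.implicits._

// An RDD of case class instances built from made-up data.
val someRDD = spark.sparkContext.parallelize(Seq(Wine("Italy", 87.0), Wine("France", 90.0)))

// The resulting DataFrame has the columns "country" and "points".
val dataFrame = someRDD.toDF()
```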

Defining the Schema

The schema of the data can often be inferred automatically, but if that option isn’t available for our data or we simply want to define it manually, we have three main ways of doing so:

Casting

Explicitly casting columns from one type to another. E.g.:

val dataFrame = otherDataFrame
  .withColumn("numericalColumn", otherDataFrame("numericalColumn").cast(DoubleType))


StructType

Using the StructType and StructField types to explicitly define the DataType of each column. E.g.:

val schemaStruct = StructType(
    StructField("intColumn", IntegerType, true) ::
    StructField("longColumn", LongType, true) ::
    StructField("booleanColumn", BooleanType, true) :: Nil)

val df = spark.read
  .schema(schemaStruct)
  .option("header", true)
  .csv("some/path/to/file.csv") // placeholder path


Encoders

Encoders are a concept from Spark SQL’s serialization and deserialization framework. We can use them to provide the schema via a case class.

case class SchemaClass(intColumn: Int, longColumn: Long, booleanColumn: Boolean)

val schemaEncoded = Encoders.product[SchemaClass].schema

val df = spark.read
  .schema(schemaEncoded)
  .option("header", true)
  .csv("some/path/to/file.csv") // placeholder path

Saving a DataFrame

We can save a DataFrame to persistent storage by using the DataFrameWriter interface, which we obtain from a DataFrame by calling its write method.

Writing a DataFrame is almost identical to reading one in most cases; we just call the methods mentioned before on write instead of read. E.g. writing a DataFrame loaded from a CSV file to a JSON file:

val dataFrame = spark.read.csv("someFile.csv")
dataFrame.write.json("someFile.json") // output path chosen for illustration


Exploring a DataFrame

We have two main methods used in inspecting the contents and structure of a DataFrame (or any other Dataset) – show and printSchema.

The show method comes in five versions:

  • show() – displays the top 20 rows in tabular form.
  • show(numRows: Int) – shows the top numRows rows in tabular form.
  • show(truncate: Boolean) – shows top 20 rows in a tabular form. If truncate is true then strings that are longer than 20 characters will be truncated and cells become right-aligned.
  • show(numRows: Int, truncate: Boolean) – show the top numRows rows in tabular form. If truncate is true then strings longer than 20 characters will be truncated and cells become right-aligned.
  • show(numRows: Int, truncate: Int) – show the top numRows rows in tabular form. If truncate is more than 0 then strings longer than truncate characters will be truncated and cells become right-aligned.

The printSchema() method prints the schema to the console in a tree format.
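A short sketch of these methods in action might look as follows. The two sample rows and the local[*] master are made up for the illustration; the exact console output is left out on purpose, as it is best seen by running the snippet.

```scala
import org.apache.spark.sql.SparkSession

// Local session for experimenting (assumption); in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .appName("show-sketch")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

// Made-up data; the first note is longer than 20 characters on purpose.
val df = Seq(
  ("Ann", "a fairly long note that exceeds twenty characters"),
  ("Bob", "short")
).toDF("name", "note")

df.show()        // top 20 rows, strings longer than 20 characters are truncated
df.show(1)       // only the first row
df.show(false)   // top 20 rows without truncation
df.show(2, 10)   // two rows, strings cut down to at most 10 characters
df.printSchema() // column names and types as a tree
```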

DataFrame Operations

The Dataset interface allows us to execute operations on the data via a SQL-like DSL or by running SQL queries programmatically. A DataFrame is simply a Dataset of Rows, so it is not strongly typed. This is why its operations are called untyped.

Importing spark.implicits._ brings implicits into scope that let us use a richer notation when operating on the tables.

Untyped Operation

A simple example of filtering by the value of someColumn and then selecting anotherColumn as a result to be shown:

val result = dataFrame.filter($"someColumn" > 0).select("anotherColumn")

The $ string interpolator is provided by spark.implicits and lets us create a Column reference from a String.

A comprehensive list of available operations can be found in the Dataset API documentation.

There is also a very comprehensive set of string manipulation and math functions available. They are listed in the documentation of the org.apache.spark.sql.functions object.
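The sketch below combines a few untyped operations with functions from org.apache.spark.sql.functions. The wines DataFrame and its values are invented for the example and are not related to the data set used later in this post.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, col, upper}

// Local session for experimenting (assumption); in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .appName("untyped-ops-sketch")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

// Made-up sample data.
val wines = Seq(
  ("Italy", 87.0, 20.0),
  ("France", 90.0, 35.0),
  ("Italy", 91.0, 40.0)
).toDF("country", "points", "price")

// Filter by a column value, then select a transformed column and a plain one.
val expensive = wines
  .filter($"price" > 25)
  .select(upper(col("country")).as("country"), $"points")

// Group the rows and aggregate with a built-in function.
val avgByCountry = wines
  .groupBy("country")
  .agg(avg("price").as("avgPrice"))
```

Nothing is computed until an action such as show or count is called on expensive or avgByCountry.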

Running SQL Queries

We also have the option of running a SQL query programmatically with the sql method, which takes the query string as its argument.

To do that, we first need to register the DataFrame as a SQL temporary view. This makes the DataFrame visible as a table from the SQL query. It can be done with the createOrReplaceTempView method, e.g.:

dataFrame.createOrReplaceTempView("dataFrameTable")

And now running a SQL query with the sql method:

val result = spark.sql("SELECT * FROM dataFrameTable")

The temporary view is session-scoped, so it disappears when the session terminates. We can also create a global temporary view that is shared among all sessions and kept alive until the application terminates. Global temporary views are tied to the global_temp database, so to access them we must use the qualified name with the global_temp. prefix. An example of creating and accessing such a view:

dataFrame.createGlobalTempView("globalDataFrameTable")

val result = spark.sql("SELECT * FROM global_temp.globalDataFrameTable")
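Putting both kinds of views together, a self-contained sketch could look like this. The table names, the sample data and the local[*] master are assumptions; createOrReplaceGlobalTempView is used instead of createGlobalTempView only so that the snippet can be re-run without failing on an already existing view.

```scala
import org.apache.spark.sql.SparkSession

// Local session for experimenting (assumption); in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .appName("sql-views-sketch")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

// Made-up sample data.
val wines = Seq(
  ("Italy", 87.0, 20.0),
  ("France", 90.0, 35.0),
  ("Italy", 91.0, 40.0)
).toDF("country", "points", "price")

// Session-scoped view: visible only through this SparkSession.
wines.createOrReplaceTempView("wines")
val top = spark.sql(
  "SELECT country, price FROM wines WHERE points >= 90 ORDER BY price DESC")

// Global view: visible from other sessions of the same application via global_temp.
wines.createOrReplaceGlobalTempView("winesGlobal")
val fromNewSession = spark.newSession()
  .sql("SELECT COUNT(*) AS n FROM global_temp.winesGlobal")
```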


Pipelines

The Pipeline concept revolves around the idea of providing a uniform API for creating and composing machine learning data-transformation steps into a single, concise workflow. It also makes it possible to persist pipelines and reuse ones that we created and saved earlier. The concept is analogous to stream processing in e.g. Akka Streams.

A Pipeline can consist of the following elements:

  • Transformer – an abstraction of DataFrame transformers. It provides a transform function that maps a DataFrame into a new one by e.g. adding a column, changing the rows of a specific column or predicting the label based on the feature vector.
  • Estimator – an abstraction of algorithms that fit or train on data (e.g. regression algorithms). It provides a fit function that maps a DataFrame into a Model.

Additionally, Transformer and Estimator share a common API for specifying their parameters, based on Param and ParamMap, as an alternative to using setters. More information about this concept can be found in the official ML Pipelines documentation.
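The difference between the two abstractions, and the parameter API, can be sketched as follows. StringIndexer and VectorAssembler are used here only as convenient examples of an Estimator and a Transformer; the data, the column names and the local[*] master are made up.

```scala
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorAssembler}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SparkSession

// Local session for experimenting (assumption); in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .appName("transformer-estimator-sketch")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

// Made-up sample data.
val data = Seq(("Italy", 87.0), ("France", 90.0), ("Italy", 91.0)).toDF("country", "points")

// Estimator: fit learns the label-to-index mapping and returns a Model,
// which is itself a Transformer.
val indexer = new StringIndexer()
  .setInputCol("country")
  .setOutputCol("countryIndex")
val indexerModel: StringIndexerModel = indexer.fit(data)
val indexed = indexerModel.transform(data)

// Transformer: nothing to learn, transform can be called right away.
val assembler = new VectorAssembler()
  .setInputCols(Array("points", "countryIndex"))
  .setOutputCol("features")
val assembled = assembler.transform(indexed)

// Parameters can also be passed as a ParamMap instead of using setters;
// here the output column is overridden for this single fit call.
val paramMap = ParamMap(indexer.outputCol -> "countryIdx")
val indexedAlt = indexer.fit(data, paramMap).transform(data)
```

By default StringIndexer assigns the lowest index to the most frequent label, so "Italy" should end up with index 0.0 in this sample.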


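A Pipeline is, in essence, an ordered array of stages, where each stage is either a Transformer or an Estimator. A Pipeline is itself an Estimator: when we call fit on it, the stages are run in order. A Transformer stage simply transforms the DataFrame, while an Estimator stage is fitted and the resulting Model (which is a Transformer) transforms the data passed on to the next stage. A Pipeline can therefore contain several Estimator stages, although the learning algorithm itself is typically the last one.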

An exemplary Pipeline for some simple regression task:

1. Converting categorical features into indexes.

2. Normalizing the vectors in the frame.

3. Linear regression.


We can easily save a created Pipeline or Model for later use. Not all Transformer and Estimator types are supported, so checking their docs for specific information is a good idea. Most of the basic transformers and models are supported. The methods:

  • save(path: String) – save the Model/Pipeline to the location pointed by path
  • load(path: String) – load a Model/Pipeline from the location pointed by path
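A minimal sketch of a save and load round trip is shown below. The training data, the stage configuration, the temporary directory and the local[*] master are assumptions chosen only to keep the snippet self-contained; note that load is called on the companion object of the saved type, here PipelineModel.

```scala
import java.nio.file.Files
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

// Local session for experimenting (assumption); in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .appName("persistence-sketch")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

// Made-up training data.
val training = Seq(
  ("Italy", 87.0, 20.0),
  ("France", 90.0, 35.0),
  ("Italy", 91.0, 40.0),
  ("Spain", 85.0, 15.0)
).toDF("country", "points", "price")

val indexer = new StringIndexer().setInputCol("country").setOutputCol("countryIndex")
val assembler = new VectorAssembler()
  .setInputCols(Array("points", "countryIndex"))
  .setOutputCol("features")
val lr = new LinearRegression().setLabelCol("price").setFeaturesCol("features")

// Fitting the Pipeline yields a PipelineModel that can be persisted.
val model = new Pipeline().setStages(Array(indexer, assembler, lr)).fit(training)

// Save to a fresh location inside a temporary directory and load it back.
val path = Files.createTempDirectory("pipeline-sketch").resolve("model").toString
model.save(path)
val restored = PipelineModel.load(path)

// The restored model can be used exactly like the original one.
val predictions = restored.transform(training)
```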


Here is a short example of how to create a Pipeline (note that the setStages method takes an Array as the argument):

val indexer = new VectorIndexer()

val normalizer = new Normalizer()

val lr = new LinearRegression()

val pipeline = new Pipeline()
  .setStages(Array(indexer, normalizer, lr))

Transformers and Estimators in Spark

MLlib comes with an extensive set of Transformer and Estimator elements that we can use in our machine learning workflows. The documentation provided for each of them is excellent and I suggest checking it out. You can find it in the official MLlib guide.

The regression/classification algorithms in the library operate on two columns – a features column holding a vector of Double values and a label column holding a Double. Thus categorical values have to be converted to numbers with an indexer, and the values of the individual feature columns need to be collected into a single vector (e.g. by using a VectorAssembler).

Spark also offers us a way to define our own Transformer and Estimator components if the ones provided aren’t enough. For further information, I would suggest reading Extending the Pipeline by Tomasz Sosiński.


Finally, I would like to present a complete code example of doing regression on a real-world dataset (we’ll be looking at only a small portion of it).

We’ll try to tackle a regression problem of predicting the price of wine based on two variables – WineEnthusiast rating and the country where it was made. We’ll use this data set for doing so. The unpacked file is renamed to wine-data.csv and moved to the application’s working directory.

The WineEnthusiast variable is closer in definition to an ordinal variable if you look at its values and description, but we’ll treat it as a Double for the sake of the example. Country is a categorical (nominal) variable and because of that needs to be indexed for the feature vector. Then we’ll collect the new columns into a single vector named features using the VectorAssembler.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.regression.GBTRegressor
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import org.apache.spark.sql.{Encoders, SparkSession}

object Main {

    def main(args: Array[String]): Unit = {

        //The local master is an assumption for running on a single machine
        val spark = SparkSession.builder
            .appName("Wine Price Regression")
            .master("local[*]")
            .getOrCreate()

        //We'll define a partial schema with the values we are interested in. For the sake of the example points is a Double
        //Note: Spark matches a user-supplied CSV schema to the file's columns by position, not by header name
        val schemaStruct = StructType(
            StructField("points", DoubleType) ::
            StructField("country", StringType) ::
            StructField("price", DoubleType) :: Nil
        )

        //We read the data from the file taking into account there's a header.
        //na.drop() will return rows where all values are non-null.
        val df = spark.read
            .option("header", true)
            .schema(schemaStruct)
            .csv("wine-data.csv")
            .na.drop()

        //We'll split the set into training and test data
        val Array(trainingData, testData) = df.randomSplit(Array(0.8, 0.2))

        val labelColumn = "price"

        //We define a StringIndexer for the categorical variable
        val countryIndexer = new StringIndexer()
            .setInputCol("country")
            .setOutputCol("countryIndex")

        //We define the assembler to collect the columns into a new column with a single vector - "features"
        val assembler = new VectorAssembler()
            .setInputCols(Array("points", "countryIndex"))
            .setOutputCol("features")

        //For the regression we'll use the Gradient-boosted tree estimator
        val gbt = new GBTRegressor()
            .setLabelCol(labelColumn)
            .setFeaturesCol("features")
            .setPredictionCol("Predicted " + labelColumn)

        //We define the Array with the stages of the pipeline
        val stages = Array(countryIndexer, assembler, gbt)

        //Construct the pipeline
        val pipeline = new Pipeline().setStages(stages)

        //We fit our DataFrame into the pipeline to generate a model
        val model = pipeline.fit(trainingData)

        //We'll make predictions using the model and the test data
        val predictions = model.transform(testData)

        //This will evaluate the error/deviation of the regression using the Root Mean Squared deviation
        val evaluator = new RegressionEvaluator()
            .setLabelCol(labelColumn)
            .setPredictionCol("Predicted " + labelColumn)
            .setMetricName("rmse")

        //We compute the error using the evaluator
        val error = evaluator.evaluate(predictions)

        println(error)

        spark.stop()
    }
}



I hope that the article helped with understanding the basics behind MLlib and how to utilize it in your machine-learning adventures.

As we can see, the library (and Spark in general) provides us with a well-designed API and workflow for machine learning. Of course, there is much more to learn for anyone who wants to explore this topic further. But, as I mentioned before, Spark comes with great documentation that lets us pursue it.

In the last section, I’ve attached some links that, in my opinion, should be useful for expanding our knowledge further.
Happy coding ;)

Cheers, Marcin

Useful Links

Marcin Gorczyński

I have a broad knowledge and extensive experience in developing high and low-level software using C, Python, Java, JavaScript and Scala. Haskell and FP enthusiast. My main non-IT hobbies and interests include music, playing guitar(s) and piano, politics, philosophy and ancient history (mainly Greek).
