In this chapter, we will explain some basic concepts of machine learning (ML) that will be used in all subsequent chapters. We will start with a brief introduction to ML, including the basic learning workflow, the ML rule of thumb, and different learning tasks. Then we will gradually cover the most important ML tasks.
We will also introduce Scala and Scala-based ML libraries to prepare for the next chapter. Finally, we will get started with ML using Scala and Spark ML by solving a real-life problem. The chapter will briefly cover the following topics:
- Overview of ML
- ML tasks
- Introduction to Scala
- Scala ML libraries
- Getting started with ML with Spark ML
You'll be required to have basic knowledge of Scala and Java. Since Scala is also a JVM-based language, make sure both Java JRE and JDK are installed and configured on your machine. To be more specific, you'll need Scala 2.11.x and Java 1.8.x version installed. Also, you need an IDE, such as Eclipse, IntelliJ IDEA, or Scala IDE, with the necessary plugins. However, if you're using IntelliJ IDEA, Scala will already be integrated.
The code files of this chapter can be found on GitHub:
ML approaches are based on a set of statistical and mathematical algorithms in order to carry out tasks such as classification, regression analysis, concept learning, predictive modeling, clustering, and mining of useful patterns. Using ML, we aim to improve the whole learning process automatically such that we may not need complete human interactions, or we can at least reduce the level of such interactions as much as possible.
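Tom M. Mitchell explained what learning really means from a computer science perspective: "A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P, if its performance at tasks in T, as measured by P, improves with experience E."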
Based on this definition, we can conclude that a computer program or machine can do the following:
- Learn from data and histories
- Improve with experience
- Iteratively enhance a model that can be used to predict outcomes of questions
Since the preceding points are at the core of predictive analytics, almost every ML algorithm we use can be treated as an optimization problem. This is about finding parameters that minimize an objective function, for example, a weighted sum of two terms such as a cost function and regularization. Typically, an objective function has two components:
- A regularizer, which controls the complexity of the model
- The loss, which measures the error of the model on the training data
The regularization parameter defines the trade-off between minimizing the training error and limiting the model's complexity, in an effort to avoid overfitting. If both of these components are convex, then their sum is also convex. So, when using an ML algorithm, the goal is to obtain the parameters of a function that return the minimum error when making predictions. Therefore, by using a convex optimization technique, we can minimize the function until it converges toward the minimum error.
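To make the two components concrete, the following is a minimal sketch in plain Scala, not tied to any ML library. It assumes a one-parameter linear model, a mean squared error loss, and an L2 regularizer. The toy data, the step size, and the names (RegularizedObjective, minimize, and so on) are illustrative assumptions rather than anything prescribed by this chapter:

```scala
object RegularizedObjective {
  // Toy training data (hypothetical): y is roughly 2 * x
  val xs: Vector[Double] = Vector(1.0, 2.0, 3.0, 4.0)
  val ys: Vector[Double] = Vector(2.1, 3.9, 6.2, 7.8)

  // Loss: mean squared error of the one-parameter model y = w * x
  def loss(w: Double): Double =
    xs.zip(ys).map { case (x, y) => math.pow(w * x - y, 2) }.sum / xs.size

  // Regularizer: L2 penalty on the weight, which discourages large weights
  def regularizer(w: Double): Double = w * w

  // Objective = loss + lambda * regularizer (both terms are convex in w)
  def objective(w: Double, lambda: Double): Double =
    loss(w) + lambda * regularizer(w)

  // Analytic gradient of the objective with respect to w
  def gradient(w: Double, lambda: Double): Double =
    xs.zip(ys).map { case (x, y) => 2 * x * (w * x - y) }.sum / xs.size + 2 * lambda * w

  // Plain gradient descent with a fixed step size, starting from w = 0
  def minimize(lambda: Double, stepSize: Double = 0.01, iterations: Int = 1000): Double =
    (1 to iterations).foldLeft(0.0)((w, _) => w - stepSize * gradient(w, lambda))
}
```

With these toy numbers, RegularizedObjective.minimize(0.0) settles near 1.99, whereas a positive lambda such as 1.0 pulls the weight toward zero, which is the trade-off that the regularization parameter controls.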
Given that a problem is convex, it is usually easier to analyze the asymptotic behavior of the algorithm, which shows how fast it converges as the model observes more and more training data. The task of ML is to train a model so that it can recognize complex patterns from the given input data and can make decisions in an automated way.
Inferencing, in turn, is all about testing the model against new (that is, unobserved) data and evaluating the performance of the model itself. Throughout the whole process, data is the first-class citizen in all ML tasks, because the success of a predictive model depends on it. In reality, the data that we feed to our machine learning systems must be made up of mathematical objects, such as vectors, so that the systems can consume it. For example, in the following diagram, raw images are embedded into numeric values called feature vectors before being fed into the learning algorithm:
Depending on the available data and feature types, the performance of your predictive model can vary dramatically. Therefore, selecting the right features is one of the most important steps before the inferencing takes place. This is called feature engineering, where domain knowledge about the data is used to create only selective or useful features, which are then assembled into the feature vectors that a machine learning algorithm works on.
For example, comparing hotels is quite difficult unless we already have a personal experience of staying in multiple hotels. However, with the help of an ML model, which is already trained with quality features out of thousands of reviews and features (for example, how many stars does a hotel have, size of the room, location, room service, and so on), it is pretty feasible now. We'll see several examples throughout the chapters. However, before developing such an ML model, knowing some ML concepts is also important.
The general machine learning rule of thumb is that the more data there is, the better the predictive model. However, having more features often creates a mess, to the extent that the performance degrades drastically, especially if the dataset is high-dimensional. The entire learning process requires input datasets that can be split into three types (or are already provided as such):
- A training set is the knowledge base coming from historical or live data that is used to fit the parameters of the ML algorithm. During the training phase, the ML model uses the training set to find optimal weights (for example, of a neural network) and optimizes the objective function by minimizing the training error. Here, the back-prop rule or another optimization algorithm is used to train the model, but all the hyperparameters must be set before the learning process starts.
- A validation set is a set of examples used to tune the hyperparameters of an ML model. It helps ensure that the model is trained well and generalizes rather than overfits. Some ML practitioners refer to it as a development set or dev set as well.
- A test set is used for evaluating the performance of the trained model on unseen data. This step is also referred to as model inferencing. After assessing the final model on the test set (that is, when we're fully satisfied with the model's performance), we do not have to tune the model any further, but the trained model can be deployed in a production-ready environment.
A common practice is splitting the input data (after necessary pre-processing and feature engineering) into 60% for training, 10% for validation, and 30% for testing, but it really depends on the use case. Sometimes, we also need to perform up-sampling or down-sampling on the data based on the availability and quality of the datasets.
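As a rough illustration of such a split, here is a small sketch in plain Scala. The object name DataSplit, the split function, and the seed value are assumptions made for this example. In a Spark application, you would normally use the DataFrame's randomSplit() method instead:

```scala
import scala.util.Random

object DataSplit {
  // Shuffle with a fixed seed, then cut into training, validation, and test sets.
  // Whatever remains after the training and validation fractions becomes the test set.
  def split[A](data: Seq[A], trainFrac: Double, valFrac: Double, seed: Long): (Seq[A], Seq[A], Seq[A]) = {
    require(trainFrac > 0 && valFrac >= 0 && trainFrac + valFrac < 1.0,
      "fractions must leave room for a test set")
    val shuffled = new Random(seed).shuffle(data)
    val nTrain = math.round(data.size * trainFrac).toInt
    val nVal = math.round(data.size * valFrac).toInt
    val (train, rest) = shuffled.splitAt(nTrain)
    val (validation, test) = rest.splitAt(nVal)
    (train, validation, test)
  }
}
```

For example, DataSplit.split(1 to 100, 0.6, 0.1, 42L) returns three sequences of 60, 10, and 30 elements. Fixing the seed keeps the split reproducible between runs.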
This rule of thumb of learning on different types of training sets can differ across machine learning tasks, as we will cover in the next section. However, before that, let's take a quick look at a few common phenomena in machine learning.
When we use this input data for training, validation, and testing, the learning algorithms usually cannot learn with 100% accuracy, which gives rise to training, validation, and test error (or loss). There are two types of error that one can encounter in a machine learning model:
- Irreducible error
- Reducible error
The irreducible error cannot be reduced even with the most robust and sophisticated model. However, the reducible error, which has two components, called bias and variance, can be reduced. Therefore, to understand the model (that is, prediction errors), we need to focus on bias and variance only:
- Bias means how far the predicted values are from the actual values. Usually, if the average predicted values are very different from the actual values (labels), then the bias is higher.
- An ML model will have a high bias when it is too simple to model the relationship between input and output variables (that is, it can't capture the complexity of the data well). Thus, a too-simple model with high bias causes underfitting of the data.
The following diagram gives some high-level insights and also shows what a just-right fit model should look like:
Variance signifies the variability between the predicted values and the actual values (how scattered they are).
An ML model with high variance usually performs very well on the training set but doesn't work well on the test set (because of high error rates). Ultimately, it results in an overfit model. We can recap overfitting and underfitting once more:
- Underfitting: If your training and validation error are both relatively equal and very high, then your model is most likely underfitting your training data.
- Overfitting: If your training error is low and your validation error is high, then your model is most likely overfitting your training data. The just-right fit model learns very well and performs better on unseen data too.
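The recap above can be turned into a simple rule-of-thumb check. The following sketch is only illustrative: the thresholds highError and maxGap are hypothetical values chosen for this example, and in practice they depend on the problem and the error metric:

```scala
object FitDiagnosis {
  sealed trait Fit
  case object Underfitting extends Fit
  case object Overfitting extends Fit
  case object JustRight extends Fit

  // highError and maxGap are hypothetical thresholds; choose them per problem
  def diagnose(trainError: Double, validationError: Double,
               highError: Double = 0.3, maxGap: Double = 0.1): Fit =
    // Both errors high and similar: the model is too simple
    if (trainError >= highError && validationError >= highError) Underfitting
    // Low training error but much higher validation error: the model memorized the training data
    else if (validationError - trainError > maxGap) Overfitting
    else JustRight
}
```

For instance, FitDiagnosis.diagnose(0.02, 0.35) reports Overfitting, while FitDiagnosis.diagnose(0.40, 0.42) reports Underfitting.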
Now we know the basic working principle of an ML algorithm. However, based on problem type and the method used to solve a problem, ML tasks can be different, for example, supervised learning, unsupervised learning, and reinforcement learning. We'll discuss these learning tasks in more detail in the next section.
Although every ML problem is more or less an optimization problem, the way they are solved can vary. In fact, learning tasks can be categorized into three types: supervised learning, unsupervised learning, and reinforcement learning.
Supervised learning is the simplest and most well-known automatic learning task. It is based on a number of predefined examples, in which the category to which each of the inputs should belong is already known, as shown in the following diagram:
The preceding diagram shows a typical workflow of supervised learning. An actor (for example, a data scientist or data engineer) performs extract, transform, load (ETL) and the necessary feature engineering (including feature extraction, selection, and so on) to get the appropriate data with features and labels so that it can be fed into the model. The actor then splits the data into training, validation (development), and test sets. The training set is used to train an ML model, the validation set is used to validate the training against overfitting and to tune regularization, and then the actor evaluates the model's performance on the test set (that is, unseen data).
However, if the performance is not satisfactory, the actor can perform additional tuning to get the best model based on hyperparameter optimization. Finally, the actor deploys the best model in a production-ready environment. The following diagram summarizes these steps in a nutshell:
In the overall life cycle, there might be many actors involved (for example, a data engineer, data scientist, or an ML engineer) to perform each step independently or collaboratively. The supervised learning context includes classification and regression tasks; classification is used to predict which class a data point is a part of (discrete value). It is also used for predicting the label of the class attribute. On the other hand, regression is used for predicting continuous values and making a numeric prediction of the class attribute.
In the context of supervised learning, the input dataset required for the learning process is split randomly into three sets, for example, 60% for the training set, 10% for the validation set, and the remaining 30% for the testing set.
How would you summarize and group a dataset if the labels were not given? You would probably try to answer this question by finding the underlying structure of the dataset and measuring statistical properties such as frequency distribution, mean, standard deviation, and so on. And if the question were how to effectively represent data in a compressed format, you would probably reply that you'd use some software for the compression, although you might have no idea how that software does it. The following diagram shows the typical workflow of an unsupervised learning task:
These are exactly two of the main goals of unsupervised learning, which is largely a data-driven process. We call this type of learning unsupervised because you will have to deal with unlabeled data. The following quote comes from Yann LeCun, director of AI research (source: Predictive Learning, NIPS 2016, Yann LeCun, Facebook Research):
The most widely used unsupervised learning tasks include the following:
- Clustering: Grouping data points based on similarity (or statistical properties). For example, a company such as Airbnb often groups its apartments and houses into neighborhoods so that customers can navigate the listed ones more easily.
- Dimensionality reduction: Compressing the data with the structure and statistical properties preserved as much as possible. For example, often the number of dimensions of the dataset needs to be reduced for the modeling and visualization.
- Anomaly detection: Useful in several applications, such as detecting credit card fraud, identifying faulty pieces of hardware in an industrial engineering process, and identifying outliers in large-scale datasets.
- Association rule mining: Often used in market basket analysis, for example, asking which items are frequently bought together.
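To give a feel for clustering, the first task in the preceding list, here is a minimal sketch of the k-means idea on one-dimensional data in plain Scala. It is a simplified illustration under stated assumptions (fixed iteration count, caller-supplied initial centroids, hypothetical names), not the implementation used by Spark MLlib:

```scala
object KMeans1D {
  // Assign each point to the index of its nearest centroid
  def assign(points: Seq[Double], centroids: Seq[Double]): Seq[Int] =
    points.map(p => centroids.indices.minBy(i => math.abs(p - centroids(i))))

  // Recompute each centroid as the mean of its members (keep it unchanged if the cluster is empty)
  def update(points: Seq[Double], labels: Seq[Int], centroids: Seq[Double]): Seq[Double] =
    centroids.indices.map { i =>
      val members = points.zip(labels).collect { case (p, l) if l == i => p }
      if (members.isEmpty) centroids(i) else members.sum / members.size
    }

  // Alternate the assignment and update steps for a fixed number of iterations
  def fit(points: Seq[Double], initial: Seq[Double], iterations: Int = 10): Seq[Double] =
    (1 to iterations).foldLeft(initial) { (cs, _) => update(points, assign(points, cs), cs) }
}
```

No labels are involved anywhere: given points such as 0.8, 1.0, 1.2, 8.5, 9.0, and 9.5 and initial centroids 0.0 and 5.0, the centroids move to the centers of the two natural groups purely from the similarity of the data points.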
Reinforcement learning is an artificial intelligence approach that focuses on the learning of the system through its interactions with the environment. In reinforcement learning, the system's parameters are adapted based on the feedback obtained from the environment, which in turn provides feedback on the decisions made by the system. The following diagram shows a person making decisions in order to arrive at their destination. Let's take an example of the route you take from home to work:
In this case, you take the same route to work every day. However, out of the blue, one day you get curious and decide to try a different route with a view to finding the shortest path. Then, based on your experience and the time taken with the different route, you'd decide whether you should take a specific route more often. We can take a look at one more example in terms of a system modeling a chess player. In order to improve its performance, the system utilizes the result of its previous moves; such a system is said to be a system learning with reinforcement.
So far, we have learned the basic working principles of ML and different learning tasks. However, a summarized view of each learning task with some example use cases is also needed, which we will see in the next subsection.
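The route example can be sketched in a few lines of Scala. The following is a deliberately simplified illustration of learning from feedback: it only keeps a running average of the observed travel time per route and picks the best one so far. It leaves out exploration strategies and everything else a real reinforcement learning algorithm needs, and the names RouteLearner, Estimates, observe, and best are hypothetical:

```scala
object RouteLearner {
  // Running estimate of travel time (in minutes) per route, plus how often each route was tried
  final case class Estimates(avgMinutes: Map[String, Double], trips: Map[String, Int]) {
    // Feedback from the environment: fold one observed travel time into the running average
    def observe(route: String, minutes: Double): Estimates = {
      val n = trips.getOrElse(route, 0) + 1
      val old = avgMinutes.getOrElse(route, 0.0)
      Estimates(avgMinutes.updated(route, old + (minutes - old) / n), trips.updated(route, n))
    }

    // Choose the route with the lowest estimated travel time so far
    // (throws if no route has been observed yet)
    def best: String = avgMinutes.minBy(_._2)._1
  }

  val empty: Estimates = Estimates(Map.empty, Map.empty)
}
```

After observing, say, 30 and 34 minutes on the usual route and 25 minutes on a new one, best returns the new route, which mirrors how the commuter adapts the decision to the feedback received.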
We have seen the basic working principles of ML algorithms. Then we have seen what the basic ML tasks are and how they formulate domain-specific problems. However, each of these learning tasks can be solved using different algorithms. The following diagram provides a glimpse into this:
The following diagram summarizes the previously mentioned ML tasks and some applications:
However, the preceding diagram lists only a few use cases and applications using different ML tasks. In practice, ML is used in numerous use cases and applications. We will try to cover a few of those throughout this book.
Scala is a scalable, functional, and object-oriented programming language that is most closely related to Java. However, Scala is designed to be more concise and have features of functional programming languages. For example, Apache Spark, which is written in Scala, is a fast and general engine for large-scale data processing.
Scala's success is due to many factors: it has many tools that enable succinct expression, it is concise, so there is less code to type and to read, and it offers very good performance as well. This is why Spark has more support for Scala, in the sense that more of its APIs are available in Scala than in R, Python, or Java. Scala's symbolic operators are easy to read and, compared to Java, which is often verbose, most Scala code is concise and easy to read. Functional programming concepts such as pattern matching and higher-order functions are also available in Scala.
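As a quick taste of the two functional features just mentioned, consider the following small sketch. The Shape hierarchy and the applyTwice function are made up for this illustration:

```scala
object ScalaFeatures {
  sealed trait Shape
  final case class Circle(radius: Double) extends Shape
  final case class Rectangle(width: Double, height: Double) extends Shape

  // Pattern matching: pick the formula based on the structure of the shape
  def area(shape: Shape): Double = shape match {
    case Circle(r)       => math.Pi * r * r
    case Rectangle(w, h) => w * h
  }

  // Higher-order function: takes another function as an argument
  def applyTwice(f: Int => Int, x: Int): Int = f(f(x))

  def main(args: Array[String]): Unit = {
    val shapes = List(Circle(1.0), Rectangle(2.0, 3.0))
    println(shapes.map(area))        // map is itself a higher-order function
    println(applyTwice(_ + 3, 10))   // applies the anonymous function twice
  }
}
```

Because Shape is sealed, the compiler can warn when a match does not cover every case, which is one reason pattern matching tends to stay both concise and safe.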
The best way to get started with Scala is either using Scala through the Scala build tool (SBT) or to use Scala through an integrated development environment (IDE). Either way, the first important step is downloading, installing, and configuring Scala. However, since Scala runs on Java Virtual Machine (JVM), having Java installed and configured on your machine is a prerequisite. Therefore, I'm not going to cover how to do that. Instead, I will provide some useful links (https://en.wikipedia.org/wiki/Integrated_development_environment).
Java programmers normally prefer Scala when they need to add some functional programming flavor to their codes as Scala runs on JVM. There are various other options when it comes to editors. The following are some options to choose from:
- Scala IDE
- Scala plugin for Eclipse
- IntelliJ IDEA
Eclipse has several advantages using numerous beta plugins and local, remote, and high-level debugging facilities with semantic highlighting and code completion for Scala.
Although Scala is a relatively new programming language compared to Java and Python, the question will arise as to why we need to consider learning it while we have Python and R. Well, Python and R are two leading programming languages for rapid prototyping and data analytics including building, exploring, and manipulating powerful models.
But Scala is becoming the key language too in the development of functional products, which are well suited for big data analytics. Big data applications often require stability, flexibility, high speed, scalability, and concurrency. All of these requirements can be fulfilled with Scala because Scala is not only a general-purpose language but also a powerful choice for data science (for example, Spark MLlib/ML). I've been using Scala for the last couple of years and I found that more and more Scala ML libraries are in development. Up next, we will discuss available and widely used Scala libraries that can be used for developing ML applications.
MLlib is a library that provides user-friendly ML algorithms that are implemented using Scala. The same API is then exposed to provide support for other languages such as Java, Python, and R. Spark MLlib provides support for local vectors and matrix data types stored on a single machine, as well as distributed matrices backed by one or multiple resilient distributed datasets (RDDs).
RDD is the primary data abstraction of Apache Spark, often called Spark Core, that represents an immutable, partitioned collection of elements that can be operated on in parallel. The resiliency makes RDD fault-tolerant (based on RDD lineage graph). RDD can help in distributed computing even when data is stored on multiple nodes in a Spark cluster. Also, RDD can be converted into a dataset as a collection of partitioned data with primitive values such as tuples or other objects.
Spark ML is a new set of ML APIs that allows users to quickly assemble and configure practical machine learning pipelines on top of datasets, which makes it easier to combine multiple algorithms into a single pipeline. For example, an ML algorithm (called an estimator) and a set of transformers (for example, a StringIndexer, a StandardScaler, and a VectorAssembler) can be chained together as stages to perform the ML task, without having to invoke each stage manually one after another.
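A chained pipeline of this kind might look like the following sketch. It assumes that the Spark ML library is on the classpath and that the input DataFrame has a string column named category, a numeric column named amount, and a numeric label column. These column names, as well as the object name PipelineSketch, are hypothetical:

```scala
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{StandardScaler, StringIndexer, VectorAssembler}

object PipelineSketch {
  // Maps the string values of a column to numeric indices (fitted on the data first)
  val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex")

  // Transformer: combines several numeric columns into a single vector column
  val assembler = new VectorAssembler()
    .setInputCols(Array("categoryIndex", "amount"))
    .setOutputCol("rawFeatures")

  // Learns scaling statistics from the data, then rescales the feature vector
  val scaler = new StandardScaler().setInputCol("rawFeatures").setOutputCol("features")

  // Estimator: the learning algorithm itself
  val lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("features")

  // The stages are applied in this order whenever fit() or transform() is called
  val pipeline = new Pipeline().setStages(Array[PipelineStage](indexer, assembler, scaler, lr))
}
```

Calling PipelineSketch.pipeline.fit(trainingDF) on a suitable DataFrame (trainingDF is a placeholder here) returns a PipelineModel, whose transform() method then runs all the fitted stages on new data in a single call.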
At this point, I have to inform you of something very useful. Since we will be using the Spark MLlib and ML APIs in upcoming chapters too, it is worth fixing some issues in advance. If you're a Windows user, let me tell you about a very weird issue that you might experience while working with Spark. Spark works on Windows, macOS, and Linux. However, while using Eclipse or IntelliJ IDEA to develop your Spark applications on Windows, you might face an I/O exception and, consequently, your application might not run successfully or may be interrupted.
Spark needs a runtime environment for Hadoop on Windows too. Unfortunately, the binary distribution of Spark (v2.4.0, for example) does not contain Windows-native components such as winutils.exe or hadoop.dll. However, these are required (not optional) to run Hadoop on Windows. If you cannot ensure the runtime environment, an I/O exception saying the following will appear:
03/02/2019 11:11:10 ERROR util.Shell: Failed to locate the winutils binary in the hadoop binary path
java.io.IOException: Could not locate executable null\bin\winutils.exe in the Hadoop binaries.
This issue can be tackled on Windows, from IDEs such as Eclipse and IntelliJ IDEA, with the following steps:
- Download winutils.exe from https://github.com/steveloughran/winutils/tree/master/hadoop-2.7.1/bin/.
- Copy it into the bin folder of the Spark distribution, for example, spark-2.2.0-bin-hadoop2.7/bin/.
- Select Project | Run Configurations... | Environment | New | and create a variable named HADOOP_HOME, then put the path of the folder that contains bin in the Value field. Here is an example: c:/spark-2.2.0-bin-hadoop2.7/ | OK | Apply | Run.
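If you prefer not to touch the IDE's run configuration, a commonly used alternative is to set the hadoop.home.dir JVM system property from the code before the SparkSession is created, since Hadoop consults this property as well as the HADOOP_HOME variable. The following is a sketch under the assumption that winutils.exe was copied to the location used in the preceding steps. The object name and the path are placeholders to adapt:

```scala
object HadoopHomeSetup {
  // Hypothetical location: it must be the folder whose bin subfolder holds winutils.exe
  val hadoopHome = "C:/spark-2.2.0-bin-hadoop2.7"

  // Set the JVM property only on Windows; call this before creating the SparkSession
  def configure(): Unit = {
    val isWindows = System.getProperty("os.name").toLowerCase.contains("windows")
    if (isWindows) System.setProperty("hadoop.home.dir", hadoopHome)
  }
}
```

Calling HadoopHomeSetup.configure() as the first statement of your main method keeps the workaround inside the project, so it also applies when the application is started outside the IDE.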
ScalNet is a wrapper around Deeplearning4J intended to emulate a Keras-like API for developing deep learning applications. If you're already familiar with neural network architectures and are coming from a JVM background, it would be worth exploring the Scala-based ScalNet library:
- GitHub (https://github.com/deeplear…/deeplearning4j/…/master/scalnet)
- Example (https://github.com/…/sc…/org/deeplearning4j/scalnet/examples)
DynaML is a Scala and JVM ML toolbox for research, education, and industry. This library provides an interactive, end-to-end, and enterprise-friendly way of developing ML applications. If you're interested, see more at https://transcendent-ai-labs.github.io/DynaML/.
Breeze is one of the primary scientific computing libraries for Scala, which provides a fast and efficient way of data manipulation operations such as matrix and vector operations for creating, transposing, filling with numbers, conducting element-wise operations, and calculating determinants.
Breeze enables basic operations based on the netlib-java library, which enables extremely fast algebraic computations. In addition, Breeze provides a way to perform signal-processing operations, necessary for working with digital signals.
- Breeze (https://github.com/scalanlp/breeze/)
- Breeze examples (https://github.com/scalanlp/breeze-examples)
- Breeze quickstart (https://github.com/scalanlp/breeze/wiki/Quickstart)
On the other hand, ScalaNLP is a suite of libraries for scientific computing, ML, and natural language processing, which also acts as an umbrella project for several libraries, including Breeze and Epic. Vegas is another Scala library, for data visualization, which allows plotting specifications such as filtering, transformations, and aggregations. For plotting, Vegas is more functional than the numerical processing library, Breeze.
Whereas the visualization library of Breeze is backed by Breeze and JFreeChart, Vegas can be considered a missing Matplotlib for Scala and Spark, because it provides several options for rendering plots through and within interactive notebook environments, such as Jupyter and Zeppelin.
In this section, we'll see a real-life example of a classification problem. The idea is to develop a classifier that, given the values for sex, age, time, number of warts, type, and area, will predict whether a patient has to go through cryotherapy.
We will use a recently added cryotherapy dataset from the UCI machine learning repository. The dataset can be downloaded from http://archive.ics.uci.edu/ml/datasets/Cryotherapy+Dataset+#.
This dataset contains information about wart treatment results of 90 patients using cryotherapy. In case you don't know, a wart is a kind of skin problem caused by infection with a type of human papillomavirus. Warts are typically small, rough, and hard growths that are similar in color to the rest of the skin.
There are two available treatments for this problem:
- Salicylic acid: A type of gel containing salicylic acid used in medicated band-aids.
- Cryotherapy: A freezing liquid (usually nitrogen) is sprayed onto the wart. It will destroy the cells in the affected area. After the cryotherapy, usually, a blister develops, which eventually turns into a scab and falls off after a week or so.
There are 90 samples or instances that were either recommended to go through cryotherapy or be discharged without cryotherapy. There are seven attributes in the dataset:
- sex: Patient gender, characterized by 1 (male) or 0 (female).
- age: Patient age.
- Time: Observation and treatment time in hours.
- Number_of_Warts: Number of warts.
- Type: Types of warts.
- Area: The amount of affected area.
- Result_of_Treatment: The recommended result of the treatment, characterized by either 1 (yes) or 0 (no). It is also the target column.
As you can understand, it is a classification problem because we will have to predict discrete labels. More specifically, it is a binary classification problem. Since this is a small dataset with only six features, we can start with a very basic classification algorithm called logistic regression, where the logistic function is applied to a linear combination of the features to get the probability of a sample belonging to either class. We will learn more details about logistic regression and other classification algorithms in Chapter 3, Scala for Learning Classification. Note, however, that the code in this section trains a Spark ML-based decision tree classifier in Scala.
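To illustrate what applying the logistic function means, here is a tiny sketch in plain Scala. The weights and the bias would normally be learned from the training data; in this example they are simply passed in, and the names LogisticSketch, probability, and predict are illustrative:

```scala
object LogisticSketch {
  // Logistic (sigmoid) function: squashes any real number into the interval (0, 1)
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Probability of the positive class for the linear score w . x + b
  def probability(weights: Seq[Double], bias: Double, features: Seq[Double]): Double = {
    require(weights.size == features.size, "weights and features must have the same length")
    sigmoid(weights.zip(features).map { case (w, x) => w * x }.sum + bias)
  }

  // Turn the probability into a discrete label using a threshold
  def predict(p: Double, threshold: Double = 0.5): Int = if (p >= threshold) 1 else 0
}
```

A linear score of zero maps to a probability of exactly 0.5, large positive scores approach 1, and large negative scores approach 0, which is what makes the function suitable for binary classification.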
I am assuming that Java is already installed on your machine and JAVA_HOME is set too. Also, I'm assuming that your IDE has the Maven plugin installed. If so, then just create a Maven project and add the project properties as follows:
In the preceding properties tag, I specified the Spark version (that is, 2.3.0), but you can adjust it. Then add the following dependencies in the pom.xml file:
Then, if everything goes smoothly, all the JAR files will be downloaded in the project home as Maven dependencies. Alright! Then we can start writing the code.
Since you're here to learn how to solve a real-life problem in Scala, exploring available Scala libraries would be worthwhile. Unfortunately, we don't have many options except for Spark MLlib and ML, which can be used for classification very easily and comfortably. Importantly, they have many classification algorithms implemented as high-level interfaces. I assume that Scala, Java, and your favorite IDE such as Eclipse or IntelliJ IDEA are already configured on your machine. We will introduce some concepts of Spark without providing much detail, but we will continue learning in upcoming chapters too.
First, I'll introduce SparkSession, which is a unified entry point of a Spark application introduced in Spark 2.0. Technically, SparkSession is the gateway to Spark's functionality; constructs such as SparkContext, HiveContext, and SQLContext are all encapsulated in a SparkSession. A SparkSession can be created using the builder pattern as follows:
import org.apache.spark.sql.SparkSession
val spark = SparkSession
  .builder // the builder itself
  .master("local[*]") // run locally on all cores (use local[4] for 4 threads)
  .config("spark.sql.warehouse.dir", "/temp") // Spark SQL warehouse location
  .appName("SparkSessionExample") // name of the Spark application
  .getOrCreate() // get the existing session or create a new one
The preceding builder will try to get an existing SparkSession or create a new one. Then the newly created SparkSession will be assigned as the global default.
Creating a DataFrame is probably the most important task in every data analytics task. Spark provides a read() method that can be used to read data from numerous sources in various formats such as CSV, JSON, Avro, and JDBC. For example, the following code snippet shows how to read a CSV file and create a Spark DataFrame:
val dataDF = spark.read
  .option("header", "true") // treat the first line as column names
  .option("inferSchema", "true") // infer the column types from the data
  .format("csv") // built-in CSV source (Spark 2.0+); replaces com.databricks.spark.csv
  .load("data/inputData.csv") // path of the CSV file
  .cache() // [Optional] cache if necessary
Once a DataFrame is created, we can see a few samples (that is, rows) by invoking the show() method, as well as print the schema using the printSchema() method. Invoking describe().show() will show the statistics about the DataFrame:
dataDF.show() // shows the first 20 rows by default
dataDF.printSchema() // shows the schema (including column name and type)
dataDF.describe().show() // shows descriptive statistics
In many cases, we have to import spark.implicits._, which is one of the most useful imports. It brings a lot of implicit methods into scope for converting Scala objects to datasets and vice versa. Once we have created a DataFrame, we can create a view (temporary or global) for performing SQL using either the createOrReplaceTempView() method or the createGlobalTempView() method, respectively:
dataDF.createOrReplaceTempView("myTempDataFrame") // create or replace a local temporary view with dataDF
dataDF.createGlobalTempView("myGloDataFrame") // create a global temporary view with dataframe dataDF
Now a SQL query can be issued to see the data in tabular format:
spark.sql("SELECT * FROM myTempDataFrame").show() // shows the records (first 20 by default)
To drop these views, spark.catalog.dropTempView("myTempDataFrame") or spark.catalog.dropGlobalTempView("myGloDataFrame"), respectively, can be invoked. By the way, once you're done, simply invoke the spark.stop() method; it will destroy the SparkSession and release all the resources allocated by the Spark application. Interested readers can read detailed API documentation at https://spark.apache.org/ to get more information.
There is a Cryotherapy.xlsx Excel file, which contains data as well as data usage agreement texts. So, I just copied the data and saved it in a CSV file named Cryotherapy.csv. Let's start by creating a SparkSession, the gateway to access Spark:
val spark = SparkSession.builder.master("local[*]")
  .appName("CryotherapyPrediction").getOrCreate() // the app name is illustrative
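Then let's read the dataset and see a glimpse of it:
var CryotherapyDF = spark.read.option("header", "true")
  .option("inferSchema", "true").csv("data/Cryotherapy.csv") // assumed path; adjust to where you saved the CSV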
Let's call CryotherapyDF.printSchema() to see if the preceding CSV reader managed to read the data properly, including header and types:
As seen from the following screenshot, the schema of the Spark DataFrame has been correctly identified. Also, as expected, all the features of my ML algorithms are numeric (in other words, in integer or double format):
A snapshot of the dataset can be seen using the show() method. We can limit the number of rows; here, let's say 5, by calling CryotherapyDF.show(5):
The output of the preceding line of code shows the first five samples of the DataFrame:
As per the dataset description on the UCI machine learning repository, there are no null values. Also, the Spark ML-based classifiers expect numeric values to model them. The good thing is that, as seen in the schema, all the required fields are numeric (that is, either integers or floating point values). Also, the Spark ML algorithms expect a label column, which in our case is Result_of_Treatment. Let's rename it to label using the Spark-provided withColumnRenamed() method:
// Spark ML algorithms expect a 'label' column, which in our case is 'Result_of_Treatment'. Let's rename it to 'label'
CryotherapyDF = CryotherapyDF.withColumnRenamed("Result_of_Treatment", "label")
All the Spark ML-based classifiers expect training data containing two columns called label (which we already have) and features. We have seen that we have six features. However, those features have to be assembled to create a feature vector. This can be done using VectorAssembler, which is one kind of transformer from the Spark ML library. But first we need to select all the columns except the label column:
val selectedCols = Array("sex", "age", "Time", "Number_of_Warts", "Type", "Area")
Then we instantiate a VectorAssembler() transformer and transform as follows:
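import org.apache.spark.ml.feature.VectorAssembler
val vectorAssembler = new VectorAssembler().setInputCols(selectedCols).setOutputCol("features")
val numericDF = vectorAssembler.transform(CryotherapyDF).select("label", "features")
numericDF.show() // display the assembled label and features columns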
As expected, the last line of the preceding code segment shows the assembled DataFrame having label and features, which are needed to train an ML algorithm:
Next, we separate the training and test sets. Let's say that 80% of the data will be used for training and the other 20% will be used to evaluate the trained model:
val splits = numericDF.randomSplit(Array(0.8, 0.2))
val trainDF = splits(0)
val testDF = splits(1)
Instantiate a decision tree classifier by specifying impurity, max bins, and the max depth of the trees. Additionally, we set the label and feature columns:
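val dt = new DecisionTreeClassifier().setImpurity("gini").setMaxBins(32).setMaxDepth(5) // values shown are Spark's defaults
  .setLabelCol("label").setFeaturesCol("features")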
Now that the data and the classifier are ready, we can perform the training:
val dtModel = dt.fit(trainDF)
Since it's a binary classification problem, we need the BinaryClassificationEvaluator() evaluator to evaluate the model's performance on the test set:
val evaluator = new BinaryClassificationEvaluator()
Now that the training is completed and we have a trained decision tree model, we can evaluate the trained model on the test set:
val predictionDF = dtModel.transform(testDF)
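Finally, we compute the evaluation score. Note that BinaryClassificationEvaluator reports the area under the ROC curve by default rather than plain accuracy, so the value printed as accuracy in the following code is strictly an AUC score: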
val accuracy = evaluator.evaluate(predictionDF)
println("Accuracy = " + accuracy)
You should see a score of about 0.96 (about 96%), although the exact value varies with the random split:
Accuracy = 0.9675436785432
Finally, we stop the SparkSession by invoking the stop() method, that is, spark.stop().
We have managed to achieve about 96% accuracy with minimum effort. However, there are other performance metrics such as precision, recall, and F1 measure. We will discuss them in upcoming chapters. Also, if you're a newbie to ML and haven't understood all the steps in this example, don't worry. We'll recap all of these steps in other chapters with various other examples.
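As a small preview of those metrics, the following sketch computes them from the four counts of a binary confusion matrix. It is plain Scala written for illustration only (the names ClassificationMetrics and Counts are made up), not the Spark ML evaluator API:

```scala
object ClassificationMetrics {
  // Counts from a binary confusion matrix:
  // true positives, false positives, false negatives, true negatives
  final case class Counts(tp: Int, fp: Int, fn: Int, tn: Int)

  // Fraction of all predictions that are correct
  def accuracy(c: Counts): Double = (c.tp + c.tn).toDouble / (c.tp + c.fp + c.fn + c.tn)

  // Of everything predicted positive, how much really is positive
  def precision(c: Counts): Double = c.tp.toDouble / (c.tp + c.fp)

  // Of everything that really is positive, how much was found
  def recall(c: Counts): Double = c.tp.toDouble / (c.tp + c.fn)

  // Harmonic mean of precision and recall
  def f1(c: Counts): Double = {
    val (p, r) = (precision(c), recall(c))
    2 * p * r / (p + r)
  }
}
```

Note that the functions return NaN when a denominator is zero (for example, when the model never predicts the positive class), so production code would need to handle that case explicitly.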
In this chapter, we have learned some basic concepts of ML, which is used to solve a real-life problem. We started with a brief introduction to ML including a basic learning workflow, the ML rule of thumb, and different learning tasks, and then we gradually covered important ML tasks such as supervised learning, unsupervised learning, and reinforcement learning. Additionally, we discussed Scala-based ML libraries. Finally, we have seen how to get started with machine learning with Scala and Spark ML by solving a simple classification problem.
Now that we know basic ML and Scala-based ML libraries, we can start learning in a more structured way. In the next chapter, we will learn about regression analysis techniques. Then we will develop a predictive analytics application for predicting slowness in traffic using linear regression and generalized linear regression algorithms.