Build a Multi-Class Support Vector Machine in R

Support Vector Machines (SVMs) are quite popular in the data science community. Data scientists often use SVMs for classification tasks, and they tend to perform well in a variety of problem domains. An SVM performs classification by constructing hyperplanes in a multidimensional space that separate cases with different class labels. The basic SVM is a binary classifier, i.e. it handles data with exactly two classes, but it can be extended to more classes, and in this article we’ll focus on a multi-class support vector machine in R.


The code below is based on the svm() function in the e1071 package, an interface to the LIBSVM library that implements the SVM supervised learning algorithm. After reading this article, I strongly recommend the 2006 Journal of Statistical Software paper “Support Vector Machines in R.”

A Primer on SVM Theory

An SVM classifies data by determining the optimal hyperplane that separates observations according to their class labels. The central idea is to handle classes that can be separated by linear as well as non-linear boundaries.

  • “Hyperplane classifiers” are classifiers that distinguish between members of different classes by drawing separating lines (hyperplanes, in more than two dimensions).
  • The “optimal” hyperplane is the one with the largest margin between classes.
  • “Margin” means the width of the slice parallel to the hyperplane that has no interior data points.
  • The “support vectors” are the data points closest to the separating hyperplane. They support the largest margin hyperplane in that if these points moved slightly, the hyperplane would also move.
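For readers who want the geometry in symbols, here is the standard textbook formulation of the two-class, hard-margin case (the notation is not from the article). With labels $y_i \in \{-1, +1\}$ and a hyperplane $\mathbf{w}^\top \mathbf{x} + b = 0$ scaled so that the support vectors satisfy $y_i(\mathbf{w}^\top \mathbf{x}_i + b) = 1$, the margin width is $2/\lVert \mathbf{w} \rVert$, so finding the largest-margin hyperplane amounts to:

```latex
\min_{\mathbf{w},\, b} \; \frac{1}{2}\lVert \mathbf{w} \rVert^2
\quad \text{subject to} \quad
y_i \left( \mathbf{w}^\top \mathbf{x}_i + b \right) \ge 1, \qquad i = 1, \dots, n.
```

The solution can be written as $\mathbf{w} = \sum_i \alpha_i y_i \mathbf{x}_i$, where $\alpha_i > 0$ only for the support vectors, which lie exactly on the margin. That is why moving a support vector moves the hyperplane, while moving any other point (without crossing the margin) does not.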

Many classification problems require complex decision boundaries to achieve a good separation, i.e. one that correctly classifies new observations in the test set based on the observations in the training set.

If a data distribution is essentially non-linear, the primary strategy is to map the data into a higher-dimensional space where, hopefully, the classes will be linearly separable. All data points are mapped, and the mapping is handled through mathematical functions known as “kernels,” which compute inner products between points in the higher-dimensional space without constructing that space explicitly (the so-called “kernel trick”). In the mapped space the SVM can find an optimal separating hyperplane, which corresponds to a non-linear boundary in the original space, rather than having to construct a complex curve directly.
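To make the kernel idea concrete, here is a small R sketch (the function names poly2_kernel, phi and rbf_kernel are illustrative, not part of e1071). It checks that a degree-2 polynomial kernel evaluated in the original two-dimensional space gives the same value as an ordinary inner product after an explicit mapping into three dimensions. With gamma = 1 and coef0 = 0 this matches the polynomial kernel (gamma*u'v + coef0)^degree documented for e1071; the radial kernel, exp(-gamma*|u - v|^2), corresponds to an infinite-dimensional mapping, so it can only be used through the kernel trick.

```r
# Degree-2 polynomial kernel with no constant term: K(u, v) = (u . v)^2
poly2_kernel <- function(u, v) sum(u * v)^2

# Explicit feature map for 2-D inputs that this kernel corresponds to:
# phi(x) = (x1^2, sqrt(2) * x1 * x2, x2^2)
phi <- function(x) c(x[1]^2, sqrt(2) * x[1] * x[2], x[2]^2)

u <- c(1, 2)
v <- c(3, 4)

k_implicit <- poly2_kernel(u, v)     # computed in the original 2-D space
k_explicit <- sum(phi(u) * phi(v))   # inner product in the mapped 3-D space
print(c(implicit = k_implicit, explicit = k_explicit))

# Radial basis kernel, as documented for e1071: exp(-gamma * |u - v|^2)
rbf_kernel <- function(u, v, gamma) exp(-gamma * sum((u - v)^2))
rbf_kernel(u, v, gamma = 0.1)
```

The two values agree (up to floating-point rounding), even though the kernel never builds the three-dimensional vectors.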

Multi-Class SVM Example

Now let’s turn to an example of a multi-class SVM. The code below divides R’s familiar iris dataset into training and testing sets. We train a model using the training set.

Training

In the call to svm() we’ll use the formula Species~. which indicates we want to classify the Species response variable using the four other predictors in the data set. We also specify the type of model with type="C-classification" for classification; this is the default when the response is a factor. svm() can also be used for regression problems, and regression is the default when the response is numeric.
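To see how svm() chooses between classification and regression, the short sketch below fits one model with the factor Species as the response and one with the numeric Petal.Width as the response, leaving type unset in both. Going by e1071’s documented defaults, summary() should report C-classification for the first and eps-regression for the second.

```r
library(e1071)
data(iris)

# Factor response: svm() defaults to classification
fit_class <- svm(Species ~ ., data = iris)

# Numeric response (Petal.Width): svm() defaults to regression
fit_reg <- svm(Petal.Width ~ ., data = iris)

# summary() reports the SVM-Type chosen for each fit
summary(fit_class)
summary(fit_reg)
```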

The kernel argument has a variety of possible types including linear, polynomial, radial, and sigmoid. We use kernel="radial" (the default) for this multi-class classification problem.
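A quick way to compare the four kernels is to fit one model per kernel on the same split and compare test-set accuracy. The sketch below reuses the article’s 75/25 split with set.seed(314) and leaves every other argument at its default, so it is a rough comparison rather than a tuned one; with a test set this small, differences of one or two observations should not be over-interpreted.

```r
library(e1071)
data(iris)

# Same train/test split as in the article
set.seed(314)
n <- nrow(iris)
tindex <- sample(n, round(n * 0.75))
train_iris <- iris[tindex, ]
test_iris  <- iris[-tindex, ]

kernels <- c("linear", "polynomial", "radial", "sigmoid")

# Fit one model per kernel (other arguments at their defaults)
# and record the proportion of test observations classified correctly
acc <- sapply(kernels, function(k) {
  fit <- svm(Species ~ ., data = train_iris,
             type = "C-classification", kernel = k)
  mean(predict(fit, test_iris) == test_iris$Species)
})
print(round(acc, 3))
```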

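You can tune the operation of svm() with two additional arguments, gamma and cost. gamma is a parameter of the kernel function (used by every kernel except linear); for the radial kernel, larger values make the decision boundary follow individual training points more closely. cost is the penalty for violating the margin: when cost is small, the margins will be wide, resulting in many support vectors; when it is large, the margins narrow and the model fits the training data more tightly. You can experiment with different values of gamma and cost to find the best classification accuracy.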
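Rather than trying values by hand, e1071 provides tune.svm(), which runs a grid search with cross-validation (10-fold by default). A minimal sketch, assuming the same training split as the article; the grid values are arbitrary choices for illustration. Tuning only on the training set keeps the test set untouched for the final accuracy estimate.

```r
library(e1071)
data(iris)

# Same training split as in the article
set.seed(314)
n <- nrow(iris)
tindex <- sample(n, round(n * 0.75))
train_iris <- iris[tindex, ]

# Grid search over gamma and cost, scored by cross-validated error
tuned <- tune.svm(Species ~ ., data = train_iris, kernel = "radial",
                  gamma = 10^(-2:0), cost = 10^(0:2))

summary(tuned)               # error for every (gamma, cost) pair
tuned$best.parameters        # the pair with the lowest error
best_svm <- tuned$best.model # model refit with the best pair
```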

> library(e1071)   

> data(iris)
> n <- nrow(iris)  # Number of observations
> ntrain <- round(n*0.75)  # 75% for training set
> set.seed(314)    # Set seed for reproducible results

> tindex <- sample(n, ntrain)   # Create a random index
> train_iris <- iris[tindex,]   # Create training set
> test_iris <- iris[-tindex,]   # Create test set

> svm1 <- svm(Species~., data=train_iris, 
          type="C-classification", kernel="radial", 
          gamma=0.1, cost=10)

The summary() method for svm objects provides useful information about how the model was trained. We see that the model found 22 support vectors, split across the three classes as 10, 3, and 9.

> summary(svm1)
Call:
  svm(formula = Species ~ ., data = train_iris, type = "C-classification", kernel = "radial", gamma = 0.1, cost = 10)

Parameters:
  SVM-Type:  C-classification 
SVM-Kernel:  radial 
cost:  10 
gamma:  0.1 

Number of Support Vectors:  22
( 10 3 9 )

Number of Classes:  3 
Levels: 
  setosa versicolor virginica
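Because the counts in parentheses are printed without class labels, one way to see which species each count belongs to is to look up the true class of every support vector. svm1$index holds the row positions of the support vectors within the training data, so tabulating their labels gives labeled counts. The sketch below refits the same model; the exact numbers depend on the random split, which can differ between R versions.

```r
library(e1071)
data(iris)

# Same split and model as in the article
set.seed(314)
n <- nrow(iris)
tindex <- sample(n, round(n * 0.75))
train_iris <- iris[tindex, ]
svm1 <- svm(Species ~ ., data = train_iris, type = "C-classification",
            kernel = "radial", gamma = 0.1, cost = 10)

# Label each support vector with its species and count per class
sv_per_class <- table(train_iris$Species[svm1$index])
print(sv_per_class)
```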


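You can also learn something by displaying the 22 support vectors found by the SVM, stored in the svm1$SV component of the fitted model (only the first 5 are shown below). Each row is labeled with the support vector’s row name from the iris data, and the columns hold its predictor values. These values are centered and scaled, because svm() standardizes the predictors by default (scale=TRUE); the model’s coefficients are stored separately, in svm1$coefs.

> head(svm1$SV, 5)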

    Sepal.Length Sepal.Width Petal.Length Petal.Width
71   0.006532661  0.30579724 0.5624153   0.7747208
84   0.124120557 -0.84094241 0.7315157   0.5125972
88   0.476884245 -1.75833413 0.3369481   0.1194119
86   0.124120557  0.76449310 0.3933149   0.5125972
53   1.182411621  0.07644931 0.6187821   0.3815354
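The sketch below, which refits the article’s model, shows how to get back to the original measurements: indexing the training data with svm1$index returns the unscaled rows, and the check confirms that svm1$SV equals the standardized training predictors at those rows. It also prints the dimensions of svm1$coefs: for k classes, each support vector takes part in k - 1 of the pairwise classifiers, so the coefficient matrix should have k - 1 = 2 columns here.

```r
library(e1071)
data(iris)

# Same split and model as in the article
set.seed(314)
n <- nrow(iris)
tindex <- sample(n, round(n * 0.75))
train_iris <- iris[tindex, ]
svm1 <- svm(Species ~ ., data = train_iris, type = "C-classification",
            kernel = "radial", gamma = 0.1, cost = 10)

# Original (unscaled) measurements of the first five support vectors
head(train_iris[svm1$index, ], 5)

# svm1$SV should equal the standardized training predictors at those rows
scaled_x <- scale(train_iris[, 1:4])
sv_matches <- isTRUE(all.equal(svm1$SV, scaled_x[svm1$index, ],
                               check.attributes = FALSE))
print(sv_matches)

# One row per support vector, k - 1 columns of coefficients
dim(svm1$coefs)
```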


The e1071 package also provides a plot() method for svm objects that we can use to visualize the support vectors (shown with “x”, while the other observations are shown with “o”) and the class regions, whose borders are the model’s decision boundaries. The plot shows a two-dimensional projection of the data (using the Petal.Width and Petal.Length predictors), with the Species classes shown in different shadings. We can also use the slice argument to specify a list of named values for the dimensions held constant (needed when the model uses more than two predictors).

> plot(svm1, train_iris, Petal.Width ~ Petal.Length,
          slice=list(Sepal.Width=3, Sepal.Length=4))
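In the call above, Sepal.Length = 4 lies below the smallest sepal length in iris (4.3 cm), so the shaded regions describe flowers outside the observed data. One alternative, sketched here as a suggestion rather than the article’s approach, is to hold the omitted variables at their training-set medians; the variable name sepal_slice is just illustrative.

```r
library(e1071)
data(iris)

# Same split and model as in the article
set.seed(314)
n <- nrow(iris)
tindex <- sample(n, round(n * 0.75))
train_iris <- iris[tindex, ]
svm1 <- svm(Species ~ ., data = train_iris, type = "C-classification",
            kernel = "radial", gamma = 0.1, cost = 10)

# Hold the two sepal measurements at typical (median) training values
sepal_slice <- list(Sepal.Width  = median(train_iris$Sepal.Width),
                    Sepal.Length = median(train_iris$Sepal.Length))

plot(svm1, train_iris, Petal.Width ~ Petal.Length, slice = sepal_slice)
```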


Testing


Now we can use the predict() function with the trained SVM model to make predictions on the test set. The result is a factor containing the predicted class for each observation in the test set. Next, we can use the table() function to create a “confusion matrix” (true classes in rows, predicted classes in columns) for checking the accuracy of the model. We see there was only one misclassification: a versicolor flower predicted as virginica.

> prediction <- predict(svm1, test_iris)
> xtab <- table(test_iris$Species, prediction)
> xtab

          prediction
           setosa versicolor virginica
setosa         20         0          0
versicolor      0        20          1
virginica       0         0         19
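Under the hood, e1071 (via LIBSVM) handles the three species with the one-against-one approach: it trains k(k - 1)/2 binary classifiers, one per pair of classes, and predicts by voting among them. Passing decision.values = TRUE to predict() exposes the output of each pairwise classifier, as in this sketch that refits the article’s model; for three classes there should be three columns, named after the class pairs.

```r
library(e1071)
data(iris)

# Same split and model as in the article
set.seed(314)
n <- nrow(iris)
tindex <- sample(n, round(n * 0.75))
train_iris <- iris[tindex, ]
test_iris  <- iris[-tindex, ]
svm1 <- svm(Species ~ ., data = train_iris, type = "C-classification",
            kernel = "radial", gamma = 0.1, cost = 10)

# Ask for the raw output of every pairwise (one-against-one) classifier
pred <- predict(svm1, test_iris, decision.values = TRUE)
dv <- attr(pred, "decision.values")

colnames(dv)   # one column per pair of species
head(dv)
```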


Finally, we can compute the accuracy of the model on the test set: the proportion of test observations on the diagonal of the confusion matrix. Here 59 of the 60 test observations are classified correctly, an accuracy of 98.3%, which is very good.

> sum(diag(xtab)) / sum(xtab)  # Compute prediction accuracy
[1] 0.9833333
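Overall accuracy can hide uneven performance across classes. A minimal sketch of per-class recall (the fraction of each true class predicted correctly) and precision (the fraction of each predicted class that is correct), computed directly from the confusion matrix; note that precision is NaN for any class the model never predicts.

```r
library(e1071)
data(iris)

# Same split and model as in the article
set.seed(314)
n <- nrow(iris)
tindex <- sample(n, round(n * 0.75))
train_iris <- iris[tindex, ]
test_iris  <- iris[-tindex, ]
svm1 <- svm(Species ~ ., data = train_iris, type = "C-classification",
            kernel = "radial", gamma = 0.1, cost = 10)

prediction <- predict(svm1, test_iris)
xtab <- table(test_iris$Species, prediction)

accuracy  <- sum(diag(xtab)) / sum(xtab)
recall    <- diag(xtab) / rowSums(xtab)  # rows are the true classes
precision <- diag(xtab) / colSums(xtab)  # columns are the predicted classes
print(accuracy)
print(round(rbind(recall, precision), 3))
```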

Daniel Gutierrez, ODSC

Daniel D. Gutierrez is a practicing data scientist who’s been working with data since long before the field came into vogue. As a technology journalist, he enjoys keeping a pulse on this fast-paced industry. Daniel is also an educator who has taught data science, machine learning and R classes at the university level. He has authored four computer industry books on database and data science technology, including his most recent title, “Machine Learning and Data Science: An Introduction to Statistical Learning Methods with R.” Daniel holds a BS in Mathematics and Computer Science from UCLA.
