Support Vector Machines (SVMs) are quite popular in the data science community.
Data scientists often use SVMs for classification tasks, and they tend to perform well in a variety of problem domains. An SVM performs classification by constructing hyperplanes in a multidimensional space that separate cases with different class labels. The basic SVM handles data with exactly two classes, i.e. binary classification problems, but in this article we’ll focus on a multi-class support vector machine in R.
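The code below is based on the svm() function in the e1071 package, which implements the SVM supervised learning algorithm and handles more than two classes by training one binary classifier for each pair of classes and combining their decisions by voting (the “one-against-one” approach). After reading this article, I strongly recommend the 2006 Journal of Statistical Software paper “Support Vector Machines in R” for more depth.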
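For more than two classes, the e1071 documentation describes a “one-against-one” scheme: one binary classifier is trained for every pair of classes, k(k-1)/2 in total, and the predicted class is the one that wins the most pairwise votes. The base-R sketch below illustrates only the voting step. The pairwise winners in pair_winners are hypothetical values made up for illustration, not output from svm().

```r
# Class labels of R's built-in iris data
classes <- levels(iris$Species)
k <- length(classes)

# One-against-one: one binary classifier for every unordered pair of classes
pairs <- combn(classes, 2)          # 2 x k(k-1)/2 matrix, one pair per column
n_classifiers <- ncol(pairs)

# Hypothetical winners of the pairwise classifiers for a single new observation,
# in the same order as the columns of `pairs` (made up, not svm() output)
pair_winners <- c("versicolor", "virginica", "virginica")

# Sanity check: each winner must be one of the two classes in its pair
stopifnot(all(pair_winners == pairs[1, ] | pair_winners == pairs[2, ]))

# Count the votes per class and predict the class with the most votes
votes <- table(factor(pair_winners, levels = classes))
predicted <- names(votes)[which.max(votes)]

print(pairs)
print(votes)
cat("Predicted class:", predicted, "\n")
```

With three species there are three pairwise classifiers; here virginica wins two of them and is predicted.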
A Primer on SVM Theory
An SVM classifies data by determining the optimal hyperplane that separates observations according to their class labels. The central idea is to handle both classes that can be separated by a linear boundary and classes that require a non-linear boundary.
- “Hyperplane classifiers” separate the members of different classes with a hyperplane: a line in two dimensions, a plane in three, and its higher-dimensional analogue beyond that.
- The “optimal” hyperplane is the one with the largest margin between classes.
- “Margin” means the width of the band around the hyperplane, parallel to it, that contains no data points.
- The “support vectors” are the data points closest to the separating hyperplane. They support the largest margin hyperplane in that if these points moved slightly, the hyperplane would also move.
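These definitions can be checked numerically on a tiny example. The base-R sketch below uses four hypothetical 2-D points and a hand-picked linear hyperplane w·x + b = 0 (w and b are illustrative assumptions, not fitted by svm()), scaled so that the closest points satisfy |w·x + b| = 1. It computes each point’s signed distance to the hyperplane, picks the closest points as support vectors, and computes the margin width 2/||w||.

```r
# Four hypothetical 2-D points, two per class (labels -1 and +1)
X <- rbind(c(0, 0), c(1, 1), c(2, 2), c(3, 3))
y <- c(-1, -1, 1, 1)

# Hand-picked hyperplane w . x + b = 0 in canonical form: the points closest
# to it have |w . x + b| = 1 (w and b are illustrative, not fitted)
w <- c(1, 1)
b <- -3

f <- drop(X %*% w) + b          # decision values w . x + b
dist <- f / sqrt(sum(w^2))      # signed Euclidean distance to the hyperplane

all_correct <- all(y * f > 0)   # TRUE when every point lies on its class's side

# Support vectors: the points with the smallest distance to the hyperplane
sv <- which(abs(f) == min(abs(f)))

# Margin width: distance between the parallel planes w . x + b = -1 and +1
margin <- 2 / sqrt(sum(w^2))

print(cbind(X, y = y, distance = round(dist, 4)))
cat("support vectors (rows):", sv, "\nmargin width:", margin, "\n")
```

The two support vectors are (1, 1) and (2, 2), and the margin width equals the distance between them, sqrt(2); moving either point would move the hyperplane, while moving (0, 0) or (3, 3) slightly would not.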
Many classification problems require complex decision boundaries to achieve a good separation, i.e. one that correctly classifies new observations in the test set based on what was learned from the training set.
If the data are essentially non-linear, the primary strategy is to map them into a higher-dimensional space where, hopefully, the classes become linearly separable. The mapping is done through mathematical functions known as “kernels,” which compute inner products between observations in that higher-dimensional space without ever constructing the mapped points explicitly. In the mapped space the SVM can find an optimal separating hyperplane rather than constructing a complex curve; seen from the original space, that hyperplane is a non-linear boundary.
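The claim that a kernel stands in for an explicit mapping can be verified with the degree-2 polynomial kernel K(u, v) = (u·v)^2 in two dimensions, whose explicit feature map is phi(x) = (x1^2, sqrt(2)·x1·x2, x2^2). This corresponds to e1071’s documented polynomial kernel (gamma·u’v + coef0)^degree with gamma = 1, coef0 = 0 and degree = 2. In the base-R sketch below, phi and poly2_kernel are hypothetical helper names; the point is that both routes give the same number.

```r
# Explicit map from 2-D to 3-D for the homogeneous degree-2 polynomial kernel
phi <- function(x) c(x[1]^2, sqrt(2) * x[1] * x[2], x[2]^2)

# The same kernel computed directly in the original 2-D space
poly2_kernel <- function(u, v) sum(u * v)^2

u <- c(1, 2)
v <- c(3, 4)

explicit <- sum(phi(u) * phi(v))   # inner product after mapping to 3-D
implicit <- poly2_kernel(u, v)     # kernel value, no mapping needed

cat("explicit:", explicit, " implicit:", implicit, "\n")
```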
Multi-Class SVM Example
Now let’s turn to an example of a multi-class SVM. The code below divides R’s familiar iris dataset into training and testing sets. We train a model using the training set.
In the call to svm() we’ll use the formula Species~. which indicates we want to classify the Species response variable using the 4 other predictors found in the data set. We also specify the type of model we’d like with type="C-classification" for classification (the default when the response is a factor); svm() can also be used for regression problems. The kernel argument has a variety of possible types including linear, polynomial, radial, and sigmoid. We use kernel="radial" (the default) for this multi-class classification problem.
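For reference, the e1071 documentation gives the radial kernel as K(u, v) = exp(-gamma * ||u - v||^2). The base-R sketch below (rbf_kernel is a hypothetical helper, not part of e1071) evaluates it for two fixed points at several gamma values, showing that a larger gamma makes similarity fall off faster with distance, so each training point influences a smaller neighborhood.

```r
# Radial basis function kernel: exp(-gamma * squared Euclidean distance)
rbf_kernel <- function(u, v, gamma) exp(-gamma * sum((u - v)^2))

u <- c(1, 2)
v <- c(3, 4)                 # squared distance from u is 8

gammas <- c(0.01, 0.1, 1)
sims <- sapply(gammas, function(g) rbf_kernel(u, v, g))
names(sims) <- paste0("gamma=", gammas)
print(round(sims, 4))

# A point is always maximally similar to itself
self_sim <- rbf_kernel(u, u, gamma = 0.1)
```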
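You can tune the operation of svm() with two additional arguments. gamma is a parameter of the kernel function, used by every kernel except linear (the default is 1 divided by the data dimension, here 1/4). cost specifies the cost of a violation of the margin. When cost is small, the margins will be wide, resulting in many support vectors; when cost is large, margin violations are penalized heavily, which tends to produce narrower margins and fewer support vectors at a greater risk of overfitting. You can experiment with different values of gamma and cost to find the best classification accuracy.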
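Rather than trying values by hand, e1071 provides tune.svm(), which evaluates every gamma/cost combination in a grid by cross-validation (10-fold by default). The sketch below assumes e1071 is installed; the grid values are illustrative choices rather than recommendations, and the 90-row training set is a stand-in so the block runs on its own.

```r
library(e1071)

# Stand-in training set so the block is self-contained (any training split works)
set.seed(314)
train_iris <- iris[sample(nrow(iris), 90), ]

# Cross-validate every gamma/cost combination in an illustrative grid
set.seed(1)   # cross-validation folds are drawn at random
tuned <- tune.svm(Species ~ ., data = train_iris, kernel = "radial",
                  gamma = c(0.01, 0.1, 1), cost = c(1, 10, 100))

print(tuned$best.parameters)    # combination with the lowest CV error
print(tuned$best.performance)   # that cross-validated misclassification rate
best_svm <- tuned$best.model    # svm refit on all of train_iris with those values
```

Keeping the test set out of tuning matters: choosing gamma and cost by test-set accuracy would make the final accuracy estimate optimistic.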
> library(e1071)
> data(iris)
> n <- nrow(iris)                # Number of observations
> ntrain <- round(n*0.6)         # 60% for training set
> set.seed(314)                  # Set seed for reproducible results
> tindex <- sample(n, ntrain)    # Create a random index
> train_iris <- iris[tindex,]    # Create training set
> test_iris <- iris[-tindex,]    # Create test set
> svm1 <- svm(Species~., data=train_iris, type="C-classification",
+             kernel="radial", gamma=0.1, cost=10)
The summary() function for SVM provides useful information about how the model was trained. We see that the model found 22 support vectors, split into groups of 10, 3, and 9 across the three classes.
> summary(svm1)
Call:
svm(formula = Species ~ ., data = train_iris, type = "C-classification", kernel = "radial", gamma = 0.1, cost = 10)
Parameters:
   SVM-Type:  C-classification
 SVM-Kernel:  radial
       cost:  10
      gamma:  0.1
Number of Support Vectors:  22
 ( 10 3 9 )
Number of Classes:  3
Levels:
 setosa versicolor virginica
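It can also be instructive to display the 22 support vectors found by the SVM (only the first 5 are shown below), which are stored in the svm1$SV component of the fitted model. Each row is labeled with the observation’s row name in the iris data, and the columns hold that observation’s predictor values. These values are scaled, because svm() standardizes the predictors to zero mean and unit variance by default (scale=TRUE); the coefficients of the support vectors are stored separately, in svm1$coefs.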
> svm1$SV
   Sepal.Length Sepal.Width Petal.Length Petal.Width
71  0.006532661  0.30579724    0.5624153   0.7747208
84  0.124120557 -0.84094241    0.7315157   0.5125972
88  0.476884245 -1.75833413    0.3369481   0.1194119
86  0.124120557  0.76449310    0.3933149   0.5125972
53  1.182411621  0.07644931    0.6187821   0.3815354
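The fitted object stores more than the support vectors themselves. The sketch below assumes e1071 is installed and refits a model with the same settings as svm1 on a stand-in 90-row training set so it runs on its own. It prints the components index (positions of the support vectors in the training data), nSV and tot.nSV (support vector counts), and coefs (one row per support vector, one column per class minus one), then uses x.scale, the centering and scaling svm() applied to the predictors, to confirm that SV holds standardized values: undoing the scaling recovers the original measurements.

```r
library(e1071)

# Stand-in training set and model (same settings as svm1 in the article)
set.seed(314)
train_iris <- iris[sample(nrow(iris), 90), ]
fit <- svm(Species ~ ., data = train_iris, type = "C-classification",
           kernel = "radial", gamma = 0.1, cost = 10)

print(fit$tot.nSV)       # total number of support vectors
print(fit$nSV)           # support vectors per class
print(head(fit$index))   # positions of support vectors within train_iris
print(dim(fit$coefs))    # tot.nSV rows, (number of classes - 1) columns

# Undo the default standardization (scale = TRUE) applied to the predictors
ctr <- fit$x.scale[["scaled:center"]]
scl <- fit$x.scale[["scaled:scale"]]
unscaled <- sweep(sweep(fit$SV, 2, scl, "*"), 2, ctr, "+")

# Compare with the original measurements of the same training rows
original <- as.matrix(train_iris[fit$index, 1:4])
print(all.equal(unscaled, original, check.attributes = FALSE))
```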
The e1071 package also provides a special plot() method for fitted svm objects that we can use to visualize the support vectors (shown with “x”) and the decision boundaries between classes. The plot shows a two-dimensional projection of the data (using the Petal.Width and Petal.Length predictors), with the regions assigned to each Species class shown in different shadings, along with the support vectors. We can also use the slice argument to specify a list of named values for the dimensions held constant (needed when more than two predictors are used; unspecified numeric dimensions default to 0).
> plot(svm1, train_iris, Petal.Width ~ Petal.Length, slice=list(Sepal.Width=3, Sepal.Length=4))
Now we can use the predict() function with the trained SVM model to make predictions on the test set. The result is a factor containing the predicted class for each observation in the test set. Next, we can use the table() function to create a “confusion matrix” for checking the accuracy of the model. We see there was only one misclassification: a versicolor observation predicted as virginica.
> prediction <- predict(svm1, test_iris)
> xtab <- table(test_iris$Species, prediction)
> xtab
            prediction
             setosa versicolor virginica
  setosa         20          0         0
  versicolor      0         20         1
  virginica       0          0        19
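To see the one-against-one scheme at work inside a fitted model, predict() accepts decision.values = TRUE, which attaches a matrix of decision values with one column per pairwise binary classifier, three for the three iris species. The sketch assumes e1071 is installed and uses a stand-in 90/60 split so it runs on its own.

```r
library(e1071)

# Stand-in split and model so the block is self-contained
set.seed(314)
tindex <- sample(nrow(iris), 90)
train_iris <- iris[tindex, ]
test_iris  <- iris[-tindex, ]
fit <- svm(Species ~ ., data = train_iris, kernel = "radial",
           gamma = 0.1, cost = 10)

# Ask predict() for the pairwise decision values as well as the classes
pred <- predict(fit, test_iris, decision.values = TRUE)
dv <- attr(pred, "decision.values")

print(colnames(dv))     # one column per pair of classes, e.g. "setosa/versicolor"
print(head(dv, 3))
print(table(test_iris$Species, pred))
```

The sign of each column says which class of that pair the binary classifier favors; the predicted class is the one collecting the most of these pairwise votes.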
Finally, we can check the accuracy of the model with the following R code. The metric is the fraction of test observations the trained model classifies correctly: 59 of 60, or 98.3%, which is very good.
> sum(diag(xtab))/sum(xtab)    # Compute prediction accuracy
[1] 0.9833333
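Overall accuracy can hide differences between classes. This base-R sketch re-enters the confusion matrix shown above (rows are actual species, columns are predictions) and computes accuracy as the diagonal total over the grand total, plus per-class recall (correct / actual count) and precision (correct / predicted count).

```r
species <- c("setosa", "versicolor", "virginica")

# Confusion matrix from the article: rows = actual, columns = predicted
xtab <- matrix(c(20,  0,  0,
                  0, 20,  1,
                  0,  0, 19),
               nrow = 3, byrow = TRUE,
               dimnames = list(actual = species, prediction = species))

accuracy  <- sum(diag(xtab)) / sum(xtab)   # correct / total
recall    <- diag(xtab) / rowSums(xtab)    # per actual class
precision <- diag(xtab) / colSums(xtab)    # per predicted class

print(round(accuracy, 4))
print(round(rbind(recall, precision), 4))
```

Here the single error lowers versicolor’s recall (20 of 21) and virginica’s precision (19 of 20), while setosa is perfect on both.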