

How to visualize decision trees in Python
Posted by Saimadhu Polamuri on May 17, 2017

The decision tree classifier is one of the most widely used supervised learning algorithms. Unlike many other classification algorithms, a decision tree classifier is not a black box in the modeling phase: we can visualize the trained decision tree to understand how it will make predictions for the given input features.
In this article, you will learn how to visualize a trained decision tree model in Python with Graphviz. Let's begin with the table of contents.
Table of contents
- Basic introduction to decision tree classifier
- Fruit classification with decision tree classifier
- Why we need to visualize the trained decision tree
- Understand the visualized decision tree
- Visualize the decision tree in Python
- What is Graphviz
- Visualize the decision tree online
- Visualize the decision tree as pdf
Introduction to Decision tree classifier
The decision tree classifier is a widely used classification algorithm because of its advantages over other classification algorithms. These advantages are not about the accuracy of the trained decision tree model; they are about how easy the algorithm is to use and to understand.
Decision tree advantages:
- Implementation-wise, building a decision tree is simple.
- Decision trees can be used for both classification and regression problems.
- Complexity-wise, the cost of predicting with a trained decision tree is logarithmic in the number of observations used to train it (for a reasonably balanced tree).
- The trained decision tree can be visualized.
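To illustrate the second point, here is a minimal sketch, assuming scikit-learn is installed, that fits scikit-learn's `DecisionTreeClassifier` on discrete labels and `DecisionTreeRegressor` on continuous targets. The tiny one-feature datasets are made up purely for illustration.

```python
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Tiny, made-up one-feature datasets (illustration only)
X = [[1], [2], [3], [4]]
y_class = [0, 0, 1, 1]          # discrete labels -> classification
y_value = [1.5, 3.0, 4.5, 6.0]  # continuous targets -> regression

clf = DecisionTreeClassifier(random_state=0).fit(X, y_class)
reg = DecisionTreeRegressor(random_state=0).fit(X, y_value)

print(clf.predict([[1.2], [3.8]]))  # predicted class labels
print(reg.predict([[1], [4]]))      # predicted continuous values
```

Both estimators grow the same kind of tree; only the leaf values differ (a class label versus an average of the training targets in that leaf).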
Now that we know the advantages of the decision tree over other classification algorithms, let's look at a basic introduction to the decision tree.
If you have read the article about how the decision tree classifier works in machine learning, you will already be familiar with decision tree terms such as root node, leaf node, information gain, Gini index, and tree pruning.
These terms give you a basic introduction to the decision tree classifier. If you are new to decision tree classifiers, please spend some time on the articles below before you continue reading about how to visualize the decision tree in Python.
A decision tree classifier is a classification model that creates a set of rules from the training dataset; these rules are later used to predict the target class. To get a clear picture of the rules and of why visualizing the decision tree helps, let's build a toy decision tree classifier and then use it to understand the need to visualize the trained tree.
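To make the "set of rules" idea concrete, here is a minimal sketch that prints the learned rules as plain text with `sklearn.tree.export_text` (available in scikit-learn 0.21 and later; this assumes such a version). It uses the same toy fruit data that this article builds in the next section, and `random_state=0` is an addition made only so the run is reproducible.

```python
import pandas as pd
from sklearn import tree
from sklearn.tree import export_text

# Same toy fruit data used in this article (1 = apple, 0 = orange)
fruit_data_set = pd.DataFrame({
    "fruit": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
    "weight": [170, 175, 180, 178, 182, 130, 120, 130, 138, 145],
    "smooth": [9, 10, 8, 8, 7, 3, 4, 2, 5, 6],
})

features = ["weight", "smooth"]
fruit_classifier = tree.DecisionTreeClassifier(random_state=0)
fruit_classifier.fit(fruit_data_set[features], fruit_data_set["fruit"])

# Each root-to-leaf path in the printed text is one if/else rule
rules = export_text(fruit_classifier, feature_names=features)
print(rules)
```

Because either feature alone separates the two fruits perfectly, the tree needs only one split; which of the two features it picks can depend on `random_state`.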
Fruit classification with decision tree classifier
The decision tree classifier will be trained on the apple and orange features; the trained classifier can then predict the fruit label from the fruit features.
The fruit features form a dummy dataset. Below are the dataset features and targets.
Weight (grams) | Smooth (range 1 to 10) | Fruit (1 = apple, 0 = orange) |
170 | 9 | 1 |
175 | 10 | 1 |
180 | 8 | 1 |
178 | 8 | 1 |
182 | 7 | 1 |
130 | 3 | 0 |
120 | 4 | 0 |
130 | 2 | 0 |
138 | 5 | 0 |
145 | 6 | 0 |
The dummy dataset has two features and a target:
- Weight: the weight of the fruit in grams.
- Smooth: the smoothness of the fruit on a scale of 1 to 10.
- Fruit: the target, where 1 means apple and 0 means orange.
Let's follow the workflow below to model the fruit classifier.
- Loading the required Python machine learning packages
- Create and load the data in Pandas dataframe
- Building the fruit classifier with decision tree algorithm
- Predicting the fruit type from the trained classifier
Loading the required Python machine learning packages
|
# Required Python Packages
import pandas as pd
import numpy as np
from sklearn import tree
|
The required Python machine learning packages for building the fruit classifier are pandas, NumPy, and scikit-learn:
- Pandas: for loading the dataset into a DataFrame; the DataFrame is later passed as the input for training the classifier.
- NumPy: for creating the dataset arrays and performing numerical calculations.
- Sklearn (scikit-learn): for training the decision tree classifier on the loaded dataset.
Now let's create the dummy dataset and load it into a pandas DataFrame.
Create and load the data in Pandas dataframe
|
# creating dataset for modeling Apple / Orange classification
fruit_data_set = pd.DataFrame()
fruit_data_set["fruit"] = np.array([1, 1, 1, 1, 1,   # 1 for apple
                                    0, 0, 0, 0, 0])  # 0 for orange
fruit_data_set["weight"] = np.array([170, 175, 180, 178, 182,   # grams
                                     130, 120, 130, 138, 145])
fruit_data_set["smooth"] = np.array([9, 10, 8, 8, 7,   # smoothness, 1 to 10
                                     3, 4, 2, 5, 6])
|
- An empty pandas DataFrame is created to hold the fruit dataset.
- NumPy arrays are used to create the target, weight, and smooth columns.
- The target has two unique values: 1 for apple and 0 for orange.
- Weight is the weight of the fruit in grams.
- Smooth is the smoothness of the fruit on a scale of 1 to 10.
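As a quick sanity check (not part of the original workflow), per-class averages confirm that the apples in this dummy data are heavier and smoother than the oranges. This sketch builds the same DataFrame from a dictionary, an equivalent and slightly more compact alternative to filling an empty DataFrame column by column.

```python
import pandas as pd

fruit_data_set = pd.DataFrame({
    "fruit": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],  # 1 = apple, 0 = orange
    "weight": [170, 175, 180, 178, 182, 130, 120, 130, 138, 145],
    "smooth": [9, 10, 8, 8, 7, 3, 4, 2, 5, 6],
})

# Mean weight and smoothness for each fruit class
class_means = fruit_data_set.groupby("fruit")[["weight", "smooth"]].mean()
print(class_means)
```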
Now, let’s use the loaded dummy dataset to train a decision tree classifier.
Building the fruit classifier with decision tree algorithm
|
# Create the classifier and train it on the two features and the target
fruit_classifier = tree.DecisionTreeClassifier()
fruit_classifier.fit(fruit_data_set[["weight", "smooth"]], fruit_data_set["fruit"])
print(">>>>> Trained fruit_classifier <<<<<")
print(fruit_classifier)
|
- Create a decision tree classifier instance from the imported scikit-learn tree module.
- Train the decision tree model using the features and target of the loaded fruit dataset.
- Print the trained fruit classifier.
Script Output (recent scikit-learn versions print only non-default parameters, so the same code prints just DecisionTreeClassifier()):
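Printing the classifier only shows its parameters. To peek at what was actually learned, a minimal sketch could query a few attributes and methods of the fitted estimator: `classes_`, `feature_importances_`, `get_depth()` and `get_n_leaves()` (the last two assume scikit-learn 0.21 or later). The `random_state=0` argument is an addition for reproducibility.

```python
import pandas as pd
from sklearn import tree

# Same toy fruit data used in this article (1 = apple, 0 = orange)
fruit_data_set = pd.DataFrame({
    "fruit": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
    "weight": [170, 175, 180, 178, 182, 130, 120, 130, 138, 145],
    "smooth": [9, 10, 8, 8, 7, 3, 4, 2, 5, 6],
})

features = ["weight", "smooth"]
fruit_classifier = tree.DecisionTreeClassifier(random_state=0)
fruit_classifier.fit(fruit_data_set[features], fruit_data_set["fruit"])

print("Classes:", fruit_classifier.classes_)
print("Depth:", fruit_classifier.get_depth())
print("Leaves:", fruit_classifier.get_n_leaves())
# Importances sum to 1.0; with a single split, one feature receives all of it
importances = dict(zip(features, fruit_classifier.feature_importances_))
print(importances)
```

On this data a single split separates the two fruits, so the tree has depth 1 and two leaves.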
|
>>>>> Trained fruit_classifier <<<<<
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best')
|
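The default `criterion='gini'` means splits are chosen by Gini impurity, defined as 1 minus the sum of the squared class proportions in a node. Here is a small pure-Python sketch (an illustration of the formula, not scikit-learn's implementation) that computes the impurity of the fruit labels before and after a split on weight at 157.5 grams; that threshold is chosen here as the midpoint between the heaviest orange (145 g) and the lightest apple (170 g).

```python
from collections import Counter


def gini_impurity(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


fruit = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]  # 1 = apple, 0 = orange
weight = [170, 175, 180, 178, 182, 130, 120, 130, 138, 145]

root = gini_impurity(fruit)  # five apples and five oranges

# Illustrative split: weight <= 157.5 goes left, the rest goes right
left = [f for f, w in zip(fruit, weight) if w <= 157.5]
right = [f for f, w in zip(fruit, weight) if w > 157.5]

# Impurity of the children, weighted by how many samples each receives
weighted = (len(left) * gini_impurity(left)
            + len(right) * gini_impurity(right)) / len(fruit)

print(root, weighted)
```

The impurity drops from 0.5 at the root to 0.0 after the split, which is why a single split is enough for this dummy dataset.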
Now let's use the fruit classifier to predict the fruit type from the given fruit features.
Predicting the fruit type from the trained classifier
|
# fruit data set 1st observation (selected as a one-row DataFrame to keep feature names)
test_features_1 = fruit_data_set.loc[[0], ["weight", "smooth"]]
test_features_1_fruit = fruit_classifier.predict(test_features_1)
print("Actual fruit type: {act_fruit} , Fruit classifier predicted: {predicted_fruit}".format(
    act_fruit=fruit_data_set.loc[0, "fruit"], predicted_fruit=test_features_1_fruit))
# fruit data set 3rd observation
test_features_3 = fruit_data_set.loc[[2], ["weight", "smooth"]]
test_features_3_fruit = fruit_classifier.predict(test_features_3)
print("Actual fruit type: {act_fruit} , Fruit classifier predicted: {predicted_fruit}".format(
    act_fruit=fruit_data_set.loc[2, "fruit"], predicted_fruit=test_features_3_fruit))
# fruit data set 8th observation
test_features_8 = fruit_data_set.loc[[7], ["weight", "smooth"]]
test_features_8_fruit = fruit_classifier.predict(test_features_8)
print("Actual fruit type: {act_fruit} , Fruit classifier predicted: {predicted_fruit}".format(
    act_fruit=fruit_data_set.loc[7, "fruit"], predicted_fruit=test_features_8_fruit))
|
Here we take three observations from the dataset, use the trained fruit classifier to predict their fruit type, and compare the predictions with the actual fruit type.
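The three prediction calls repeat the same pattern. A more compact sketch passes several rows to `predict` at once, selecting them with `DataFrame.iloc` so the features keep their column names (recent scikit-learn versions warn when a model fitted on a DataFrame receives plain lists). Keep in mind that these rows are part of the training data, so matching predictions say nothing about accuracy on unseen fruit; `random_state=0` is added only for reproducibility.

```python
import pandas as pd
from sklearn import tree

# Same toy fruit data used in this article (1 = apple, 0 = orange)
fruit_data_set = pd.DataFrame({
    "fruit": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
    "weight": [170, 175, 180, 178, 182, 130, 120, 130, 138, 145],
    "smooth": [9, 10, 8, 8, 7, 3, 4, 2, 5, 6],
})

features = ["weight", "smooth"]
fruit_classifier = tree.DecisionTreeClassifier(random_state=0)
fruit_classifier.fit(fruit_data_set[features], fruit_data_set["fruit"])

# 1st, 3rd and 8th observations (0-based positions 0, 2 and 7)
rows = [0, 2, 7]
test_features = fruit_data_set.iloc[rows][features]
predicted = fruit_classifier.predict(test_features)
actual = fruit_data_set.iloc[rows]["fruit"].to_numpy()

for act_fruit, predicted_fruit in zip(actual, predicted):
    print("Actual fruit type: {}, Fruit classifier predicted: {}".format(
        act_fruit, predicted_fruit))
```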
Script Output: