How to visualize decision trees in Python

The decision tree classifier is one of the most popular supervised learning algorithms. Unlike many other classification algorithms, a decision tree classifier is not a black box in the modeling phase. That means we can visualize the trained decision tree to understand how it will work for the given input features.

So in this article, you are going to learn how to visualize a trained decision tree model in Python with Graphviz. Let’s begin with the table of contents.


Table of contents

  • Basic introduction to decision tree classifier
  • Fruit classification with decision tree classifier
  • Why we need to visualize the trained decision tree
  • Understand the visualized decision tree
  • Visualize decision tree in Python
    • What is Graphviz
    • Visualize the decision tree online
    • Visualize the decision tree as pdf

Introduction to Decision tree classifier

The decision tree classifier is a widely used classification algorithm because of its advantages over other classification algorithms. These advantages are not about the accuracy of the trained decision tree model. They are about how easy the algorithm is to use and understand.

Decision tree advantages:

  • Implementation-wise, building a decision tree is simple.
  • Decision trees can be used for both classification and regression problems.
  • Complexity-wise, the cost of predicting with a trained decision tree is logarithmic in the number of observations in the training dataset.
  • The trained decision tree can be visualized.

Now that we know the advantages of using a decision tree over other classification algorithms, let’s look at a basic introduction to the decision tree.

If you have gone through the article about how the decision tree classifier works in machine learning, you will be aware of decision tree keywords like root node, leaf node, information gain, Gini index, tree pruning, etc.

These keywords are the basic vocabulary of the decision tree classifier. If you are new to the decision tree classifier, please spend some time on these concepts before you continue reading about how to visualize the decision tree in Python.

A decision tree classifier is a classification model that creates a set of rules from the training dataset. Later, the created rules are used to predict the target class. To get a clear picture of the rules and of the need to visualize them, let’s build a toy decision tree classifier. Later we will use the built decision tree to understand why we need to visualize the trained decision tree.
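To make the idea of "a set of rules" concrete, here is a minimal hand-written sketch of the kind of nested if/else rules a decision tree represents. It borrows the fruit example from the next section. The function name `classify_fruit` and both thresholds are hypothetical and chosen only for illustration; a trained tree learns its own thresholds from the data.

```python
def classify_fruit(weight_grams, smoothness):
    """Hand-written rules mimicking a small decision tree.

    Returns 1 for apple and 0 for orange.
    Both thresholds are hypothetical, picked only for illustration.
    """
    if weight_grams > 160:        # root node: split on weight
        if smoothness >= 7:       # internal node: split on smoothness
            return 1              # leaf node: apple
        return 0                  # leaf node: orange
    return 0                      # leaf node: orange


print(classify_fruit(175, 9))
print(classify_fruit(130, 3))
```

Each `if` corresponds to a node in the tree and each `return` to a leaf, which is why a trained tree can be drawn and read.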

Fruit classification with decision tree classifier

Fruit classification with decision tree

The decision tree classifier will be trained on the apple and orange features. Later, the trained classifier can be used to predict the fruit label given the fruit features.

The fruit features form a dummy dataset. Below are the dataset features and the target.

Weight (grams)    Smooth (range of 1 to 10)    Fruit
170               9                            1
175               10                           1
180               8                            1
178               8                            1
182               7                            1
130               3                            0
120               4                            0
130               2                            0
138               5                            0
145               6                            0

The dummy dataset has two features and a target.

  • Weight: The weight of the fruit in grams.
  • Smooth: The smoothness of the fruit in the range of 1 to 10.
  • Fruit: The target, where 1 means apple and 0 means orange.

Let’s follow the workflow below to model the fruit classifier.

  • Loading the required Python machine learning packages
  • Create and load the data in Pandas dataframe
  • Building the fruit classifier with decision tree algorithm
  • Predicting the fruit type from the trained classifier

Loading the required Python machine learning packages

The required Python machine learning packages for building the fruit classifier are Pandas, NumPy, and Scikit-learn.

  • Pandas: For loading the dataset into a dataframe. Later, the loaded dataframe is passed as an input parameter for modeling the classifier.
  • NumPy: For creating the dataset and performing numerical calculations.
  • Sklearn: For training the decision tree classifier on the loaded dataset.

Now let’s create the dummy dataset and load it into a pandas dataframe.
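A minimal sketch of the imports this step describes could look like the following. The aliases `pd` and `np` are the usual community conventions and are an assumption here, not something the text prescribes.

```python
# Packages required for the fruit classifier
import numpy as np            # creating the arrays that make up the dataset
import pandas as pd           # holding the dataset in a dataframe
from sklearn import tree      # provides DecisionTreeClassifier
```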

Create and load the data in Pandas dataframe

  • An empty pandas dataframe is created to hold the fruit dataset.
  • NumPy arrays are created for the target, weight, and smooth columns.
    • The target has two unique values: 1 for apple and 0 for orange.
    • Weight is the weight of the fruit in grams.
    • Smooth is the smoothness of the fruit in the range of 1 to 10.
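The steps above can be sketched as follows, using the values from the dataset table. The variable name `fruit_data_set` and the column names `fruit`, `weight`, and `smooth` are assumptions made for this sketch.

```python
import numpy as np
import pandas as pd

# Start from an empty dataframe and add one column per NumPy array
fruit_data_set = pd.DataFrame()

# Target: 1 means apple, 0 means orange
fruit_data_set["fruit"] = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])

# Weight of each fruit in grams
fruit_data_set["weight"] = np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145])

# Smoothness of each fruit in the range of 1 to 10
fruit_data_set["smooth"] = np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6])

print(fruit_data_set)
```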

Now, let’s use the loaded dummy dataset to train a decision tree classifier.

Building the fruit classifier with decision tree algorithm

  • Create the decision tree classifier instance from the imported scikit-learn tree module.
  • Use the loaded fruit dataset features and the target to train the decision tree model.
  • Print the trained fruit classifier.
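A minimal sketch of these three steps is shown below. The dataset is rebuilt so the snippet runs on its own. The names `fruit_data_set` and `fruit_classifier` are assumptions, and the classifier uses scikit-learn's default settings, since the text does not mention any hyperparameters.

```python
import numpy as np
import pandas as pd
from sklearn import tree

# Rebuild the dummy fruit dataset (same values as the table above)
fruit_data_set = pd.DataFrame()
fruit_data_set["fruit"] = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
fruit_data_set["weight"] = np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145])
fruit_data_set["smooth"] = np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6])

# Create the classifier instance with default settings
fruit_classifier = tree.DecisionTreeClassifier()

# Train on the two features (weight, smooth) against the target (fruit)
fruit_classifier.fit(fruit_data_set[["weight", "smooth"]], fruit_data_set["fruit"])

# Print the trained classifier; the exact text depends on the scikit-learn version
print(fruit_classifier)
```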

Script Output:

Now let’s use the fruit classifier to predict the fruit type by giving the fruit features.

Predicting the fruit type from the trained classifier

We create three test data points, use the trained fruit classifier to predict the fruit type for each, and compare the predictions with the real fruit type.
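A sketch of this prediction step might look like the following. The three test samples and their "actual" labels are hypothetical values made up for illustration. The names `test_features` and `actual_labels` are also assumptions.

```python
import numpy as np
import pandas as pd
from sklearn import tree

# Rebuild and train the fruit classifier so the snippet runs on its own
fruit_data_set = pd.DataFrame()
fruit_data_set["fruit"] = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
fruit_data_set["weight"] = np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145])
fruit_data_set["smooth"] = np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6])

fruit_classifier = tree.DecisionTreeClassifier()
fruit_classifier.fit(fruit_data_set[["weight", "smooth"]], fruit_data_set["fruit"])

# Three hypothetical test fruits, with the same column names used in training
test_features = pd.DataFrame({"weight": [178, 128, 172], "smooth": [9, 3, 8]})
actual_labels = [1, 0, 1]  # hypothetical true labels: apple, orange, apple

# Predict the fruit type and compare with the actual label
predictions = fruit_classifier.predict(test_features)
for actual, predicted in zip(actual_labels, predictions):
    print(f"Actual fruit type: {actual}, predicted fruit type: {predicted}")
```

Passing the test features as a dataframe with the training column names keeps recent scikit-learn versions from warning about missing feature names.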

Script Output: