Decision tree classification is a widely used machine learning algorithm that is used to predict a categorical output variable based on one or more input variables. The algorithm works by constructing a tree-like model that maps the observations in the input space to the output variable. In this article, we will discuss how to implement decision tree classification in Python and R.


Implementing Decision tree classification in Python

Step 1: Import the Required Libraries

Before we start coding, we need to import the required libraries for implementing the decision tree classification algorithm in Python. We will be using the scikit-learn library to implement this algorithm. The scikit-learn library is a popular machine learning library in Python that provides various algorithms and tools for machine learning applications.

# import libraries
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
Step 2: Load the Data

The second step is to load the data. In this example, we will be using the iris dataset, which is a popular dataset in machine learning. The iris dataset contains information about the sepal length, sepal width, petal length, and petal width of three different species of iris flowers. The objective is to predict the species of the iris flower based on the input variables.

# load the data
iris = load_iris()
X = iris.data
y = iris.target
Step 3: Split the Data

The third step is to split the data into training and testing datasets. We will be using 70% of the data for training and the remaining 30% for testing.

# split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
Step 4: Train the Model

The fourth step is to train the decision tree classification model using the training data.

# train the model
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
Step 5: Test the Model

The fifth step is to test the decision tree classification model using the testing data.

# test the model
y_pred = clf.predict(X_test)
Step 6: Evaluate the Model

The final step is to evaluate the performance of the decision tree classification model. We will be using the accuracy score to evaluate the performance of the model.

# evaluate the model
from sklearn.metrics import accuracy_score
print("Accuracy:", accuracy_score(y_test, y_pred))

Implementing Decision tree classification in R

Step 1: Load the Dataset

The first step in implementing decision tree classification is to load the dataset. For this article, we will use the iris dataset, which is a popular dataset in machine learning.

To load the iris dataset, we can use the following code:

data(iris)

This will load the iris dataset into the R environment.

Step 2: Split the Dataset into Training and Test Sets

The next step is to split the dataset into training and test sets. We will use the training set to build the decision tree, and the test set to evaluate its performance.

To split the dataset, we can use the following code:

set.seed(123)
train <- sample(nrow(iris), 0.7 * nrow(iris))
train_data <- iris[train,]
test_data <- iris[-train,]

This code will split the iris dataset into training and test sets. The set.seed function is used to ensure that the split is reproducible. We are using 70% of the data for training and 30% for testing.

Step 3: Build the Decision Tree

The next step is to build the decision tree. We will use the rpart package in R to build the decision tree.

To build the decision tree, we can use the following code:

library(rpart)
fit <- rpart(Species ~ ., data=train_data, method="class")

This code will build the decision tree using the rpart function in R. The formula Species ~ . specifies that we want to predict the Species variable using all the other variables in the dataset. The method=”class” argument specifies that we are building a classification tree.

Step 4: Visualize the Decision Tree

The next step is to visualize the decision tree. We can use the plot function in R to visualize the decision tree.

To visualize the decision tree, we can use the following code:

plot(fit, margin=0.1)
text(fit, use.n=TRUE, all=TRUE, cex=.8)

This code will create a plot of the decision tree. The margin=0.1 argument specifies that we want to add a margin around the plot. The text function is used to add labels to the nodes of the decision tree.

Step 5: Make Predictions on the Test Set

The final step is to make predictions on the test set. We will use the decision tree to make predictions on the test set, and then evaluate its performance.

To make predictions on the test set, we can use the following code:

predictions <- predict(fit, test_data, type="class")

This code will make predictions on the test set using the decision tree. The type=”class” argument specifies that we want to make class predictions.


In conclusion, decision tree classification is a powerful algorithm that can be used to predict a categorical output variable based on one or more input variables. The Python scikit-learn library and R rpart library provide an easy-to-use implementation of this algorithm.

Comments welcome!