
What is a decision tree and how to use it?

Lorenzo BRACCO
ST Employee
In this article, the first of a series of MLC AI tips articles, we will explain what decision trees are and how to embed a solution based on decision trees into the latest generation of ST MEMS sensors. These sensors feature a built-in machine learning core (MLC) that can run decision tree classifiers inside the sensor itself, without the need for an MCU.

1. Introduction


The products that include MLC are easily recognizable by the “X” at the end of the product name (e.g., LSM6DSOX, LSM6DSRX, etc.).  
More details about the MLC featured in these devices can be found in the respective application note documents (AN5259 for LSM6DSOX, AN5393 for LSM6DSRX, AN5392 for ISM330DHCX and AN5536 for IIS2ICLX).

1116.png

2. Key steps behind AI projects

The flow for developing a machine learning classification algorithm usually follows a set of steps:

  1. Data collection, to capture data related to the different classes to distinguish
  2. Data labeling, to assign the correct class to each set of data (for supervised learning)
  3. Model training and validation (e.g., decision tree), based on the collected and labeled data
  4. Model deployment and real-time validation, for ST MEMS devices using MLC

The image below shows how these steps translate to the development of a solution for ST MEMS sensors using MLC. In this case, the first two steps can be handled, for example, with ST Unico-GUI, which collects the data and labels them by storing each class in a dedicated log file. Then a decision tree model can be created and trained using many different tools, such as WEKA, RapidMiner, MATLAB, Python (e.g., using scikit-learn) or the Unico-GUI built-in decision tree generation feature. Finally, Unico-GUI can also be used to generate the MLC configuration for ST devices, load it and test it in real time.
 

1117.png
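When each class is recorded in its own log file, as described above, labeling reduces to tagging every data window with the class of the file it came from. The sketch below illustrates this idea in plain Python. It is only an illustration under stated assumptions: the `recordings` dictionary stands in for the per-class log files, and the window size and the two features (mean and variance) are hypothetical choices, not the Unico-GUI log format or the MLC feature set.

```python
from statistics import mean, pvariance

# Hypothetical data: each list stands in for one per-class log file
# (e.g. accelerometer samples recorded while performing that activity).
recordings = {
    "still": [0.01, 0.02, 0.00, 0.01, 0.02, 0.01, 0.00, 0.02],
    "walking": [0.9, -0.7, 1.1, -0.8, 1.0, -0.9, 0.8, -1.0],
}


def label_windows(recordings, window_size):
    """Cut each recording into non-overlapping windows, compute two
    (hypothetical) features per window and label the window with the
    class of the recording it came from."""
    features, labels = [], []
    for label, samples in recordings.items():
        # An incomplete window at the end of a recording is dropped
        for start in range(0, len(samples) - window_size + 1, window_size):
            window = samples[start:start + window_size]
            features.append([mean(window), pvariance(window)])
            labels.append(label)
    return features, labels


X, y = label_windows(recordings, window_size=4)
print(y)  # ['still', 'still', 'walking', 'walking']
```

The resulting `X` (features) and `y` (labels) are the kind of labeled dataset that the training tools listed above expect.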

3. What is a decision tree

A decision tree is a binary tree structure made of decision nodes and leaves (the nodes that have no children). Each decision node has two branches, or children, while each leaf represents a classification output, or decision. At each decision node, the data are split according to a threshold on one of the input features: usually, a value less than the threshold selects the left branch, otherwise the right branch is chosen. Traversing the tree from the root, more and more splits are applied to the data, until a leaf is reached and the decision is made. Each leaf is associated with a class label, and each path from the root to a leaf represents a classification rule.
The picture below shows an example of a decision tree applied to the real estate business.

 

1118.png
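The traversal described in this section can be sketched in a few lines of Python. The `Node` class and the small example tree (feature indices, thresholds and class names) are hypothetical and do not reproduce the tree in the figure; the sketch follows the convention stated in the text, where a value below the threshold selects the left branch.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """A decision node (feature and threshold set) or a leaf (label set)."""
    feature: Optional[int] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    label: Optional[str] = None


def predict(node: Node, sample: list) -> str:
    # Walk from the root until a leaf is reached
    while node.label is None:
        if sample[node.feature] < node.threshold:
            node = node.left   # value below the threshold: left branch
        else:
            node = node.right  # otherwise: right branch
    return node.label


# Hypothetical tree: split on feature 0, then on feature 1 in the right subtree
tree = Node(feature=0, threshold=0.5,
            left=Node(label="class_A"),
            right=Node(feature=1, threshold=2.0,
                       left=Node(label="class_B"),
                       right=Node(label="class_C")))

print(predict(tree, [0.2, 5.0]))  # class_A
print(predict(tree, [0.8, 1.0]))  # class_B
print(predict(tree, [0.8, 3.0]))  # class_C
```

Each root-to-leaf path is one rule; for instance, the path to `class_B` reads "feature 0 >= 0.5 and feature 1 < 2.0".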

3.1. Generation of decision trees

There are different algorithms to generate decision tree classifiers. They differ not only in the way each node is split, but also in their pruning and stopping criteria. The basic steps for generating a decision tree are:

  1. Splitting (node): a splitting criterion such as Gini impurity, entropy, or information gain is used to choose the split that best separates the classes using one of the defined features. This involves selecting the best feature and the corresponding threshold.
     

    1119.png

  2. Pruning: nodes are removed from the trained decision tree to reduce its size and avoid overfitting.

  3. Stopping criteria: a decision tree can keep growing in depth until it reaches 100% accuracy on the training data, but this leads to overfitting. Algorithms therefore use different criteria to stop splitting nodes, such as a minimum number of samples in each node or a minimum gain required for a further split.
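As a minimal sketch of the splitting step, the function below searches the candidate thresholds of a single feature and keeps the one with the lowest weighted Gini impurity of the two children (that is, the largest impurity decrease). Using midpoints between consecutive distinct values as candidates is an assumption of this sketch; a complete implementation repeats the search over every feature and keeps the best feature/threshold pair.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions (0 for a pure node)."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_threshold(values, labels):
    """Return (threshold, weighted child impurity) for the best split of one
    feature. Samples with value < threshold go left, the others go right.
    Returns (None, inf) if the feature is constant."""
    n = len(values)
    distinct = sorted(set(values))
    best = (None, float("inf"))
    # Candidate thresholds: midpoints between consecutive distinct values
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(values, labels) if x < t]
        right = [y for x, y in zip(values, labels) if x >= t]
        score = len(left) / n * gini(left) + len(right) / n * gini(right)
        if score < best[1]:
            best = (t, score)
    return best


# Hypothetical feature values and class labels
values = [1.0, 1.5, 2.0, 6.0, 6.5, 7.0]
labels = ["A", "A", "A", "B", "B", "B"]
print(best_threshold(values, labels))  # (4.0, 0.0)
```

Swapping `gini` for an entropy function turns the same search into an entropy-based split.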

Each decision tree training algorithm has its own advantages; C4.5, a widely used successor of the ID3 algorithm, is robust against outliers in datasets and generates balanced decision trees.
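C4.5 ranks candidate splits by the gain ratio rather than by plain information gain: the information gain of a split is divided by the split information (the entropy of the subset sizes), which penalizes splits into many small subsets. The sketch below computes the gain ratio for a given partition of class labels; the data are hypothetical, and the rest of C4.5 (continuous attributes, missing values, pruning) is not shown.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (bits) of the class distribution."""
    n = len(labels)
    return sum(-(c / n) * log2(c / n) for c in Counter(labels).values())


def gain_ratio(parent, subsets):
    """Information gain of splitting `parent` into `subsets`, divided by the
    split information (entropy of the subset sizes)."""
    n = len(parent)
    weights = [len(s) / n for s in subsets]
    gain = entropy(parent) - sum(w * entropy(s) for w, s in zip(weights, subsets))
    split_info = -sum(w * log2(w) for w in weights if w > 0)
    return gain / split_info if split_info > 0 else 0.0


parent = ["A", "A", "B", "B"]
# Binary split that separates the two classes perfectly
print(gain_ratio(parent, [["A", "A"], ["B", "B"]]))      # 1.0
# Split into four singletons: also pure, but penalized by the split information
print(gain_ratio(parent, [["A"], ["A"], ["B"], ["B"]]))  # 0.5
```

Both splits have the same information gain (1 bit), so only the gain ratio prefers the simpler binary split.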
 

3.2. Decision tree vs random forest

Random forest is another classification method. A set of decision trees trained in different ways can produce more accurate predictions than a single decision tree: with random forest, several slightly differently trained decision trees are built and their outputs are merged to obtain more accurate and stable predictions than a single tree. The following figure illustrates the concept.
 

1120.png

One implementation of random forest works as follows. For each decision tree in the forest, the training data are randomly sampled from the available dataset and a different subset of features is chosen, so that each tree is unique. The final classification is obtained by combining the outputs of the individual trees, for example by majority vote or by averaging their class probabilities. As a result, a random forest is much less prone to overfitting than a single tree and gives better, more accurate and more stable predictions. With MLC we can replicate a similar behavior, since it can run up to 8 decision trees simultaneously on the same classification problem. The results of these trees can then be merged and processed on a microcontroller.
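The scheme above can be sketched in plain Python: each tree gets a bootstrap sample of the training rows (drawn with replacement) and a random subset of features, and the forest output is the majority vote of the tree outputs, which is one simple way to merge the results of several trees on a microcontroller. The helper names, the feature-subset size and the hard-coded tree outputs are hypothetical, and training the trees themselves is left out.

```python
import random
from collections import Counter


def draw_tree_inputs(n_rows, n_features, n_sub_features, rng):
    """Pick the training rows (bootstrap, with replacement) and the feature
    subset (without replacement) used to train one tree of the forest."""
    rows = [rng.randrange(n_rows) for _ in range(n_rows)]
    features = sorted(rng.sample(range(n_features), n_sub_features))
    return rows, features


def majority_vote(tree_outputs):
    """Merge the classes predicted by the trees into one forest output.
    Ties go to the class encountered first."""
    return Counter(tree_outputs).most_common(1)[0][0]


rng = random.Random(0)
for tree_id in range(3):
    rows, features = draw_tree_inputs(10, 6, 3, rng)
    print(tree_id, rows, features)

# Hypothetical outputs of 8 trees for the same input window
outputs = ["walking", "walking", "running", "walking",
           "still", "walking", "running", "walking"]
print(majority_vote(outputs))  # walking
```

Note that scikit-learn's `RandomForestClassifier` merges its trees by averaging their predicted class probabilities rather than by a hard vote.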
 

3.3. Build and visualize decision tree

In Python, it is possible to visualize the decision tree after training. The following code implements the necessary steps:

  • Splitting the data into training and test sets
  • Building (training) the decision tree
  • Visualizing the tree
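import pydotplus
from io import StringIO  # sklearn.externals.six is no longer available in recent scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# segm: feature matrix (one row per data segment), labels: class of each segment
X_train, X_test, y_train, y_test = train_test_split(
    segm, labels, test_size=0.3, random_state=101)

# Build and train the decision tree, using entropy as the splitting criterion
clf_entropy = DecisionTreeClassifier(criterion="entropy", random_state=100)
clf_entropy.fit(X_train, y_train)

# Export the trained tree in DOT format and render it with Graphviz
dot_data = StringIO()
export_graphviz(clf_entropy, out_file=dot_data,
                filled=True, rounded=True,
                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png("decision_tree.png")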

This is a very simple example, with a single decision node and two leaves. The code above generates a graphical representation of the decision tree, as shown below.

Example of decision tree plot in Python

1121.png

The first line in the node is the condition on a feature used to split the data. The entropy line measures how mixed the classes are in the node: if only one class is present, the entropy is zero. The samples field is the total number of samples that reach the node or leaf, and the value field gives the number of samples of each class. Note that each split reduces the sample-weighted entropy of the children compared with their parent; the larger the reduction, the more important the node.
If the decision tree is small, visualizing it helps to understand the rules (conditions) in the nodes and to check whether they are set as expected.
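The entropy printed in each node can be recomputed from its value field, and the entropy reduction of a split measures how much the node contributes. The sketch below uses hypothetical class counts, not those of the figure: it computes the entropy of a parent node and of its children, and the weighted entropy decrease of the split, which is the kind of quantity that scikit-learn sums per feature (and normalizes) for its impurity-based `feature_importances_`.

```python
from math import log2


def entropy(counts):
    """Entropy (bits) of a node given its per-class sample counts (value field)."""
    n = sum(counts)
    return sum(-(c / n) * log2(c / n) for c in counts if c > 0)


def entropy_decrease(parent, left, right):
    """Parent entropy minus the sample-weighted entropy of the two children."""
    n = sum(parent)
    children = sum(left) / n * entropy(left) + sum(right) / n * entropy(right)
    return entropy(parent) - children


# Hypothetical node: value = [50, 50], split into [45, 5] and [5, 45]
parent, left, right = [50, 50], [45, 5], [5, 45]
print(round(entropy(parent), 3))                        # 1.0
print(round(entropy(left), 3))                          # 0.469
print(round(entropy_decrease(parent, left, right), 3))  # 0.531
print(entropy([30, 0]))                                 # 0.0 (single class)
```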
 

3.4. Predictions and confusion matrix

A confusion matrix can be used to summarize and display the classification performance of a decision tree. It gives an overview of the correct classifications and of the confusion between classes. Since the accuracy of a classification algorithm can be defined in different ways, a confusion matrix is the most straightforward way to avoid misunderstandings about “accuracy”. The following figure shows an example of a confusion matrix.

Confusion matrix for audio spatial environment classification

1122.png

In the example above, the rows are the true classes and the columns are the predicted classes. Each cell contains the number of instances of the row's class that the model predicted as the column's class, which maps the ground truth to the prediction results. Here, we can see that the algorithm confuses the “street” class with the “stadium” class: the fifth cell of the last row shows that 100 samples of the “street” class were incorrectly predicted as “stadium”. The following Python code builds and plots the confusion matrix shown in the figure above.

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics

def plot_confusion_matrix(y_test, y_pred, class_list):
    # Rows are the true classes, columns are the predicted classes
    classes = sorted(class_list)
    cm = metrics.confusion_matrix(y_test, y_pred, labels=classes)
    plt.figure(figsize=(4, 3.5))
    sns.heatmap(cm, xticklabels=classes, yticklabels=classes,
                annot=True, fmt="d")
    plt.title("Confusion matrix")
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.show()

class_list = ["beach", "mall", "nature", "office", "stadium", "street"]
# clf_entropy, X_test and y_test come from the previous example
y_pred_entropy = clf_entropy.predict(X_test)
plot_confusion_matrix(y_test, y_pred_entropy, class_list)
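To make the row/column convention explicit without plotting, the confusion matrix can also be built by hand: rows are indexed by the true class and columns by the predicted class, and the diagonal holds the correct classifications, so the overall accuracy is the diagonal sum divided by the total. The labels below are hypothetical and use only two of the classes from the example; the helper names are illustrative, not part of any library.

```python
def build_confusion_matrix(y_true, y_pred, classes):
    """Rows: true class, columns: predicted class (same order as `classes`)."""
    index = {c: i for i, c in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def accuracy(matrix):
    """Fraction of samples on the diagonal (correctly classified)."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


# Hypothetical ground truth and predictions
classes = ["stadium", "street"]
y_true = ["street", "street", "street", "stadium", "stadium"]
y_pred = ["street", "stadium", "street", "stadium", "stadium"]
cm = build_confusion_matrix(y_true, y_pred, classes)
print(cm)            # [[2, 0], [1, 2]]
print(accuracy(cm))  # 0.8
```

Here `cm[1][0]` counts “street” samples predicted as “stadium”, the same kind of confusion discussed above; `sklearn.metrics.confusion_matrix` follows the same true-rows, predicted-columns convention.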
 
Last update: 2021-06-08 01:42 AM