Decision tree algorithm for categorical data

To better understand how a decision tree works, I tried writing my own implementation of the ID3 algorithm. What I found, however, is that ID3 only specifies how to evaluate each possible split of the data at a node and pick the best one. After experimenting with it for a while, I noticed that this basic approach works well only for small datasets, because at every step it has to try every splitting point the data allows. With numerical input, for example, it tries each unique value as a threshold (rows with a value smaller than the threshold go to one side) and keeps the best one, so the number of candidates grows only linearly with the number of distinct values. With categorical data, however, the number of ways to divide the categories into two groups grows exponentially with the number of categories. Is there a way to reduce the number of splits the algorithm has to try? I noticed that even well-established machine learning libraries such as scikit-learn did not solve this.
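To make the difference in growth concrete, here is a small Python sketch. It assumes the convention described above for numeric features (rows with a value smaller than the threshold go left); the helper names `numeric_thresholds` and `n_categorical_splits` are made up for illustration.

```python
def numeric_thresholds(values):
    """Candidate thresholds for a numeric feature (rows with value < t go left).

    Every distinct value except the smallest is a candidate; the smallest
    would leave the left branch empty. The count grows linearly with the
    number of distinct values.
    """
    return sorted(set(values))[1:]


def n_categorical_splits(n_categories):
    """Number of ways to split n categories into two non-empty groups.

    Each category goes left or right (2**n assignments). Dividing by 2 for
    left/right symmetry gives 2**(n-1), and dropping the split that puts
    everything on one side leaves 2**(n-1) - 1 candidates.
    """
    return 2 ** (n_categories - 1) - 1


print(numeric_thresholds([3.1, 2.0, 2.0, 5.5, 4.2]))  # [3.1, 4.2, 5.5]
for n in (4, 10, 20, 30):
    # 4 -> 7, 10 -> 511, 20 -> 524287, 30 -> 536870911
    print(n, n_categorical_splits(n))
```

Thirty categories already mean over half a billion candidate splits at a single node, while a numeric feature with thirty distinct values needs only 29.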
1 answer


One of the libraries that implements decision trees with support for categorical attributes is Spark. It reduces the number of splits that need to be checked with a simple trick. Suppose you have a binary classification problem and a categorical attribute with hundreds of categories. If N is the number of unique categories, the number of possible binary splits is on the order of 2^N (exactly 2^(N-1) - 1), which is unusable for large-scale applications.

Since in binary classification the target is only 0 or 1, Spark computes the average target value for each unique category of the attribute, i.e. the share of rows in that category with target 1. Sorting the categories by this average puts the categories where target 0 is most prevalent at the beginning and those where target 1 is most prevalent at the end. You then only need to evaluate the N-1 splits that cut this sorted list into a left and a right part, which takes the algorithm out of the dangerous O(2^N) territory down to O(N) candidate splits. For binary classification with Gini impurity or entropy this is not merely a heuristic: Breiman et al. showed in the CART book that the best binary split is always one of these N-1 splits.
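A minimal Python sketch of the idea, assuming a 0/1 target and Gini impurity as the split criterion (entropy would slot in the same way). The function name `best_ordered_split` is hypothetical, and this is not Spark's actual implementation, just the ordering trick written out plainly:

```python
from collections import defaultdict


def gini(pos, total):
    """Gini impurity of a node with `pos` rows of label 1 out of `total` rows."""
    if total == 0:
        return 0.0
    p = pos / total
    return 2 * p * (1 - p)


def best_ordered_split(categories, targets):
    """Hypothetical helper: binary split of a categorical feature via mean-target ordering.

    categories: one category label per row
    targets:    one 0/1 label per row
    Returns (left_categories, right_categories, weighted_gini),
    or None if the feature has fewer than two categories.
    """
    counts = defaultdict(lambda: [0, 0])  # category -> [rows with label 1, rows]
    for c, y in zip(categories, targets):
        counts[c][0] += y
        counts[c][1] += 1

    # Sort categories by their average target value (share of label 1).
    order = sorted(counts, key=lambda c: counts[c][0] / counts[c][1])

    total_pos = sum(p for p, _ in counts.values())
    total_n = sum(n for _, n in counts.values())

    best = None
    left_pos = left_n = 0
    # Only the N-1 "prefix" splits of the sorted order are evaluated.
    for i in range(len(order) - 1):
        p, n = counts[order[i]]
        left_pos += p
        left_n += n
        right_pos, right_n = total_pos - left_pos, total_n - left_n
        score = (left_n * gini(left_pos, left_n)
                 + right_n * gini(right_pos, right_n)) / total_n
        if best is None or score < best[2]:
            best = (order[:i + 1], order[i + 1:], score)
    return best


# The sample data from this answer.
categories = list("AAABBBCCCDDD")
targets = [1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0]
print(best_ordered_split(categories, targets))
```

On that sample the categories sort as B, D, A, C, and the function returns B and D on the left, A and C on the right, with a weighted Gini impurity of about 0.278. Sorting costs O(N log N) in the number of categories, but only N-1 impurity evaluations are needed.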

To explain it better, here is a small sample of data:
attribute target
A 1
A 0
A 1
B 0
B 0
B 0
C 1
C 1
C 1
D 1
D 0
D 0

Without the trick, you need to try every way of dividing the four categories into two non-empty groups, 2^(4-1) - 1 = 7 splits in total, e.g. AB|CD, A|BCD, AD|BC, ...
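A short Python sketch that lists these splits explicitly. `binary_partitions` is a hypothetical helper, and pinning the first category to the left side is just one way to avoid producing each split twice (AB|CD and CD|AB):

```python
from itertools import combinations


def binary_partitions(categories):
    """Yield every split of `categories` into two non-empty groups, once each.

    The first category is pinned to the left group so that a split and its
    mirror image are not both produced.
    """
    first, rest = categories[0], categories[1:]
    # size == len(rest) would move everything left and leave the right side empty.
    for size in range(len(rest)):
        for chosen in combinations(rest, size):
            left = (first, *chosen)
            right = tuple(c for c in rest if c not in chosen)
            yield left, right


splits = list(binary_partitions(["A", "B", "C", "D"]))
for left, right in splits:
    print("".join(left) + "|" + "".join(right))
print(len(splits))  # 7
```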
If we compute the average target value for each unique attribute value, we get the following:
A 0.67
B 0
C 1
D 0.33
Sorting the categories by this average gives:
B D A C
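The averaging and sorting step on the sample data, as a Python sketch (the variable names are illustrative):

```python
from collections import defaultdict

# The sample data from this answer: (attribute, target) pairs.
rows = [("A", 1), ("A", 0), ("A", 1),
        ("B", 0), ("B", 0), ("B", 0),
        ("C", 1), ("C", 1), ("C", 1),
        ("D", 1), ("D", 0), ("D", 0)]

totals = defaultdict(lambda: [0, 0])  # category -> [sum of targets, row count]
for category, target in rows:
    totals[category][0] += target
    totals[category][1] += 1

averages = {c: s / n for c, (s, n) in totals.items()}
order = sorted(averages, key=averages.get)

print({c: round(avg, 2) for c, avg in averages.items()})
# {'A': 0.67, 'B': 0.0, 'C': 1.0, 'D': 0.33}
print(order)  # ['B', 'D', 'A', 'C']
```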

Now the only splits for which we need to compute the information gain or Gini impurity are these three:
B|DAC, BD|AC, BDA|C
From these, it is easy to find the best split for this attribute.
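A Python sketch that scores the three candidate splits, assuming weighted Gini impurity rather than information gain as the criterion. `gini` and `weighted_gini` are hypothetical helpers:

```python
# The sample data from this answer: (attribute, target) pairs.
rows = [("A", 1), ("A", 0), ("A", 1),
        ("B", 0), ("B", 0), ("B", 0),
        ("C", 1), ("C", 1), ("C", 1),
        ("D", 1), ("D", 0), ("D", 0)]


def gini(labels):
    """Gini impurity of a list of 0/1 labels."""
    if not labels:
        return 0.0
    p = sum(labels) / len(labels)
    return 2 * p * (1 - p)


def weighted_gini(rows, left_categories):
    """Impurity of both children, weighted by how many rows each receives."""
    left = [y for c, y in rows if c in left_categories]
    right = [y for c, y in rows if c not in left_categories]
    return (len(left) * gini(left) + len(right) * gini(right)) / len(rows)


order = ["B", "D", "A", "C"]  # categories sorted by average target
for i in range(1, len(order)):
    left = order[:i]
    score = weighted_gini(rows, set(left))
    print("".join(left) + "|" + "".join(order[i:]), round(score, 3))
# B|DAC 0.333
# BD|AC 0.278
# BDA|C 0.333
```

BD|AC comes out lowest, at about 0.278. Scoring all seven possible splits the same way gives the same winner, as the ordering argument predicts.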

Comments

Thanks for this simple explanation. Such a simple trick with serious runtime implications. Will definitely take a closer look at Spark.

Sascha Pleßberger - Mon, 12/13/2021 - 15:14

I recently used MLlib's Random Forest for a classification task as well, and it improved my results, since it attempts to select the best split among all possible splits. Great explanation.

Parinaz Momeni ... - Wed, 12/22/2021 - 00:27