How to collect all the paths from a sklearn decision tree?

I'm trying to generate all the paths from a decision tree in sklearn. The estimator comes from a random forest, so it is a sklearn decision tree. But I'm confused by the data structure of the sklearn decision tree: it seems that left and right here contain all the left and right child nodes.
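For reference, here is a minimal sketch of how these parallel arrays are laid out, as far as I understand them: `children_left[i]` and `children_right[i]` hold the ids of the two children of node `i`, `-1` marks a leaf, and `feature` and `threshold` are `-2` at leaves. The arrays below are a hand-written toy tree in that layout (an assumption for illustration), not output from a real estimator.

```python
# Toy arrays in the same layout as estimator.tree_ (hand-written, hypothetical).
# Node ids are positions in the arrays; node 0 is the root.
children_left  = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature        = [0, -2, 1, -2, -2]
threshold      = [0.5, -2.0, 1.5, -2.0, -2.0]


def describe(node):
    """Return a one-line description of a node."""
    if children_left[node] == -1:  # -1 marks a leaf
        return f"node {node}: leaf"
    return (f"node {node}: feature[{feature[node]}] <= {threshold[node]} "
            f"-> left child {children_left[node]}, right child {children_right[node]}")


for node in range(len(children_left)):
    print(describe(node))
```

So `left` and `right` are not lists of "all the left nodes" and "all the right nodes"; each one is indexed by node id and gives that node's child on one side.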

When I tried to print out the paths, it worked fine:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
        if threshold[node] != -2:
            print('\t' * tabdepth + "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node], tabdepth + 1)
            print('\t' * tabdepth + "} else {")
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node], tabdepth + 1)
            print('\t' * tabdepth + "}")
        else:
            print('\t' * tabdepth + "return " + str(value[node]))

    recurse(left, right, threshold, features, 0)

But I need to collect all the paths in a list, and skip any path whose leaf node is "normal", so I tried the code below:

def extract_attack_rules(estimator, feature_names):
    left      = estimator.tree_.children_left
    right     = estimator.tree_.children_right
    threshold = estimator.tree_.threshold
    features  = [feature_names[i] for i in estimator.tree_.feature]
    value = estimator.tree_.value

    def recurse(left, right, threshold, features, node):
        path_lst = []

        if threshold[node] != -2:  # not leaf node
            left_cond = features[node]+"<="+str(threshold[node])
            right_cond = features[node]+">"+str(threshold[node])

            if left[node] != -1:  # not leaf node
                left_path_lst = recurse(left, right, threshold, features,left[node])
            if right[node] != -1:  # not leaf node
                right_path_lst = recurse(left, right, threshold, features,right[node])

            if left_path_lst is not None:
                path_lst.extend([left_path.append(left_cond) for left_path in left_path_lst])

            if pre_right_path is not None:
                path_lst.extend([right_path.append(right_cond) for right_path in right_path_lst])
            return path_lst

        else:  # leaf node, the attack type
            if value[node][0][0] > 0:  # if leaf is normal, not collect this path
                return None
            else:  # attack
                for i in range(len(value[node][0])):
                    if value[node][0][i] > 0:
                        return [[value[node][0][i]]]

    all_path = recurse(left, right, threshold, features, 0)

    return all_path

It returns a result so large that there is not enough memory to load it. I'm pretty sure something is wrong in the code, because all the needed paths should not be that large. I have also tried the methods from "Getting decision path to a node in sklearn", but the output of the sklearn tree structure only confused me more.

Do you know how to fix the problem here?
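A few things in `extract_attack_rules` look wrong independently of the tree structure: `list.append` returns `None`, so the two list comprehensions collect `None` values rather than paths; `pre_right_path` is never defined; and `left_path_lst` / `right_path_lst` can be unbound when a branch is skipped. One possible way around all three, sketched below, is to carry the conditions down the recursion and record a path only when a leaf is reached. The function name `collect_paths`, the toy arrays, and the feature names are hypothetical, and recording the index of the largest class value at the leaf is my assumption about what "attack type" should mean.

```python
def collect_paths(children_left, children_right, feature, threshold, value,
                  feature_names, normal_class=0):
    """Return one list per kept leaf: [cond_1, ..., cond_k, class_index]."""
    paths = []

    def recurse(node, conditions):
        if children_left[node] == -1:  # leaf node
            class_values = value[node][0]
            # Same rule as in the question: drop the path if the leaf
            # contains any "normal" samples.
            if class_values[normal_class] > 0:
                return
            # Assumption: label the path with the class of largest value.
            predicted = max(range(len(class_values)),
                            key=lambda i: class_values[i])
            paths.append(conditions + [predicted])
            return
        name = feature_names[feature[node]]
        thr = threshold[node]
        # Build a new list for each branch so siblings never share state.
        recurse(children_left[node], conditions + [f"{name} <= {thr}"])
        recurse(children_right[node], conditions + [f"{name} > {thr}"])

    recurse(0, [])
    return paths


# Hand-written toy tree in the tree_ array layout (hypothetical data).
children_left  = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature        = [0, -2, 1, -2, -2]
threshold      = [0.5, -2.0, 1.5, -2.0, -2.0]
value = [[[4, 3, 2]], [[4, 0, 0]], [[0, 3, 2]], [[0, 3, 0]], [[0, 0, 2]]]
feature_names = ["duration", "src_bytes"]

paths = collect_paths(children_left, children_right, feature, threshold,
                      value, feature_names)
for path in paths:
    print(path)
```

With a fitted estimator the same function should accept the real arrays, for example `t = estimator.tree_` followed by `collect_paths(t.children_left, t.children_right, t.feature, t.threshold, t.value, feature_names)`. The result has at most one entry per leaf, so its size is bounded by the number of leaves times the tree depth.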
