from pydotplus import graph_from_dot_data
from sklearn.tree import export_graphviz
from IPython.display import Image

# out_file=None makes export_graphviz return the DOT source as a string
dot_data = export_graphviz(tree, out_file=None, filled=True, rounded=True,
                           class_names=['Setosa', 'Versicolor', 'Virginica'],
                           feature_names=['petal length', 'petal width'])
graph = graph_from_dot_data(dot_data)
# create_png() calls the Graphviz `dot` executable to render the image
Image(graph.create_png())

Program terminated with status: 1. stderr follows: 'C:\Users\En' is not recognized as an internal or external command,
operable program or batch file.

It seems that it split my username in half. How do I overcome this?


I have a very similar example that I'm trying out. It's based on an ML how-to book that works with a Taiwan credit card dataset to predict default risk. My setup is as follows:

from io import StringIO  # standard-library StringIO; six is not needed on Python 3
from sklearn.tree import export_graphviz
from IPython.display import Image
import pydotplus

The decision tree plot is then created like this:

dot_data = StringIO()
export_graphviz(decision_tree=class_tree,
                out_file=dot_data,
                filled=True,
                rounded=True,
                feature_names = X_train.columns,
                class_names = ['pay','default'],
                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) 
Image(graph.create_png())

I suspected the problem came from the out_file=dot_data argument, but I couldn't figure out where the file path is created and stored, because print(dot_data.getvalue()) did not show any path name.
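To check that observation, here is a minimal sketch, assuming the iris dataset restricted to the two petal features from the question (the name clf and max_depth=3 are my own illustrative choices). It builds the DOT source both ways used in this thread, out_file=None and a StringIO buffer, and shows that they produce the same plain text. Since that text is only a graph description, the path in the error probably comes from the later step, where pydotplus launches the Graphviz dot executable.

```python
from io import StringIO

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Assumption: fit on iris petal length and petal width, as in the question.
iris = load_iris()
X = iris.data[:, 2:]  # columns 2 and 3 are petal length and petal width
y = iris.target
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

options = dict(
    filled=True,
    rounded=True,
    class_names=["Setosa", "Versicolor", "Virginica"],
    feature_names=["petal length", "petal width"],
)

# Style used in the question: out_file=None returns the DOT source as a string.
dot_from_string = export_graphviz(clf, out_file=None, **options)

# Style used in the answer: write the DOT source into an in-memory text buffer.
buffer = StringIO()
export_graphviz(clf, out_file=buffer, **options)
dot_from_buffer = buffer.getvalue()

print(dot_from_string == dot_from_buffer)  # True
print(dot_from_string.splitlines()[0])     # digraph Tree {
```

Neither call touches the file system or Graphviz; only graph.create_png() does.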

In my research I came across sklearn.tree.plot_tree(), which draws the same tree with matplotlib and does not need the Graphviz executables at all. So I took the export_graphviz arguments above and, wherever plot_tree had a matching argument, I added it.
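The argument matching described above can also be done programmatically. This is a small sketch using only the standard inspect module; the exact parameter sets depend on the installed scikit-learn version, so treat the printed lists as version-specific.

```python
import inspect

from sklearn.tree import export_graphviz, plot_tree

# Collect the parameter names each function accepts.
graphviz_params = set(inspect.signature(export_graphviz).parameters)
plot_tree_params = set(inspect.signature(plot_tree).parameters)

# Arguments that can be carried over unchanged from export_graphviz to plot_tree.
shared = sorted(graphviz_params & plot_tree_params)
# Arguments that only make sense for DOT output (e.g. out_file, special_characters).
only_graphviz = sorted(graphviz_params - plot_tree_params)

print("shared:", shared)
print("export_graphviz only:", only_graphviz)
```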

I ended up with the following, which produced the same image as the one in the book:

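import matplotlib.pyplot as plt
from sklearn import tree

plt.figure(figsize=(20, 10))
# plot_tree renders with matplotlib, so no Graphviz executable is involved
tree.plot_tree(class_tree,
               filled=True, rounded=True,
               feature_names=X_train.columns,
               class_names=['pay', 'default'],
               fontsize=12)
plt.show()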
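Applied to the question's setup, a self-contained sketch might look like this. It assumes the tree was trained on the iris petal length and petal width columns (the training step, max_depth=3 and the output file name are illustrative choices). The classifier is named clf rather than tree so that it does not shadow the sklearn.tree module the way the question's variable would if combined with from sklearn import tree.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the script runs without a display
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Assumption: the question's tree was fit on petal length and petal width.
iris = load_iris()
X = iris.data[:, 2:]
y = iris.target
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

fig, ax = plt.subplots(figsize=(20, 10))
# plot_tree returns the list of matplotlib annotations it drew (one box per node).
annotations = plot_tree(
    clf,
    ax=ax,
    filled=True,
    rounded=True,
    feature_names=["petal length", "petal width"],
    class_names=["Setosa", "Versicolor", "Virginica"],
    fontsize=12,
)
fig.savefig("iris_tree.png")  # hypothetical output path
plt.close(fig)
print(len(annotations), "annotations drawn")
```

In a Jupyter notebook you can drop the matplotlib.use("Agg") line and call plt.show() instead of saving to a file.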