I'm trying to run a GridSearchCV over a DecisionTreeClassifier, with the only hyper-parameter being max_depth. The two versions I ran this with are:
max_depth = range(1,20): the best_estimator_ attribute shows a max_depth of 15, while the scoring function shows 0.8880 on the test set.
max_depth = range(1,15): the best_estimator_ attribute shows a max_depth of 10, with a higher score of 0.8907.
My question is, why doesn't GridSearchCV pick a max_depth of 10 the first time around if it gives a better score?
The code is as follows:
from sklearn import tree
from sklearn.metrics import accuracy_score, fbeta_score, make_scorer
from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in scikit-learn 0.20
# X_train, X_test, y_train, y_test come from an earlier train/test split (not shown)
clf = tree.DecisionTreeClassifier(random_state=7)
parameters = {"max_depth": range(1, 20), "random_state": [7]}
scorer = make_scorer(fbeta_score, beta=0.5)
grid_obj = GridSearchCV(estimator=clf, param_grid=parameters, scoring=scorer)
grid_fit = grid_obj.fit(X_train, y_train)
best_clf = grid_fit.best_estimator_
predictions = (clf.fit(X_train, y_train)).predict(X_test)
best_predictions = best_clf.predict(X_test)
# Report the before-and-after scores
print(best_clf)
print("\nOptimized Model\n------")
print("Final accuracy score on the testing data: {:.4f}".format(accuracy_score(y_test, best_predictions)))
print("Final F-score on the testing data: {:.4f}".format(fbeta_score(y_test, best_predictions, beta=0.5)))
Your Question
My Answer (as I understand it, from more past sources than I can cite)
The deeper the tree goes, the more intricacies of the training data it learns. This is called "overfitting": the model learns the training data really well but might not generalize to unseen data. Why is the default hyperparameter max_depth=3? That is a design decision by the sklearn team. (Strictly speaking, max_depth=3 is the default for estimators such as GradientBoostingClassifier; DecisionTreeClassifier defaults to max_depth=None, which grows the tree until all leaves are pure or contain fewer than min_samples_split samples.) But why max_depth=3? The developers probably chose a default value that is applicable to most use cases. They also might have determined that 3 generalizes better on unseen data.
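To make the overfitting point concrete, here is a minimal sketch that compares training and test accuracy at a few depths. The synthetic data from make_classification and the helper accuracy_by_depth are assumptions for illustration only; they are not the asker's data or code.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, slightly noisy data (an assumption; not the asker's dataset).
X, y = make_classification(n_samples=1000, n_features=20, n_informative=5,
                           flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0)


def accuracy_by_depth(depths):
    """Return {depth: (train_accuracy, test_accuracy)} for each max_depth."""
    results = {}
    for depth in depths:
        model = DecisionTreeClassifier(max_depth=depth, random_state=7)
        model.fit(X_train, y_train)
        results[depth] = (accuracy_score(y_train, model.predict(X_train)),
                          accuracy_score(y_test, model.predict(X_test)))
    return results


# max_depth=None lets the tree grow until every leaf is pure.
scores = accuracy_by_depth([1, 3, 5, 10, None])
for depth, (train_acc, test_acc) in scores.items():
    print("max_depth={}: train={:.4f}, test={:.4f}".format(depth, train_acc, test_acc))
```

On data like this, training accuracy usually keeps climbing with depth while test accuracy levels off or drops; the exact numbers depend on the dataset.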
Decision trees are random
You won't necessarily get the same best_estimator_ every time you re-run, because scikit-learn randomly permutes the features at each split. Try using random_state to make it repeatable each time.
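A rough sketch of what that looks like in practice: with the tree's random_state fixed, two identical searches should select the same depth. The synthetic data and the helper best_depth are hypothetical stand-ins, not taken from the question.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import fbeta_score, make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (an assumption; not the asker's dataset).
X, y = make_classification(n_samples=500, n_features=10, random_state=0)


def best_depth(seed):
    """Run the depth search with a fixed tree seed and return the chosen max_depth."""
    search = GridSearchCV(
        estimator=DecisionTreeClassifier(random_state=seed),
        param_grid={"max_depth": list(range(1, 20))},
        scoring=make_scorer(fbeta_score, beta=0.5),
    )
    search.fit(X, y)
    return search.best_params_["max_depth"]


first_run = best_depth(seed=7)
second_run = best_depth(seed=7)
print("Same max_depth on both runs:", first_run == second_run)
```

This relies on the default cross-validation splitter, which does not shuffle; if you pass a shuffling splitter, give it a fixed random_state as well.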