I want to build a random forest model in R, and to find good hyperparameters I want to use the mlr package for automated hyperparameter tuning. It's a classification problem, so I want the model to predict the probabilities of the classes of the outcome variable "kategorie_who". To create the learner I use the makeLearner function with 'classif.randomForest'.
Here's the code:
library(mlr)
library(ParamHelpers)
library(randomForest)
set.seed(123)
#create a task
trainTask <- makeClassifTask(data = training_data, target = "kategorie_who")
#create a learner
rf <- makeLearner("classif.randomForest",
                  predict.type = "prob",
                  par.vals = list(ntree = 100,
                                  mtry = floor((ncol(training_data) - 1) / 3)))
#add importance = TRUE without overwriting ntree and mtry
rf <- setHyperPars(rf, importance = TRUE)
#set tune parameters
#search space for the hyperparameters
rf_param <- makeParamSet(
  makeIntegerParam("ntree", lower = 50, upper = 500),
  #mtry cannot exceed the number of predictors (all columns except the target)
  makeIntegerParam("mtry", lower = floor((ncol(training_data) - 1) / 3),
                   upper = ncol(training_data) - 1))
#random search for 10 iterations
rancontrol <- makeTuneControlRandom(maxit = 10L)
#set 3 fold cross validation
set_cv <- makeResampleDesc("CV", iters = 3L)
#hyper tuning
rf_tune <- tuneParams(learner = rf,
                      resampling = set_cv,
                      task = trainTask,
                      par.set = rf_param,
                      control = rancontrol,
                      measures = acc)
Now here's the problem: when I set predict.type = "response", everything works. But since I want the class probabilities as output, I set predict.type = "prob", and when I run the code I get the following error:
Error in checkPredictLearnerOutput(.learner, .model, p) :
predictLearner for classif.randomForest has returned not the class levels as column names: 1.1_2,1.1_3,1.1_4,1.3_4,1.4_1,1.4_2,1.4_3,1.5_1,1.6_1,1.6_3,12_Nicht anwendbar,13_Nicht anwendbar,14_Nicht anwendbar,2_1,6_1,7.1_Unsicher,7.2_Unsicher,9.1_Nicht anwendbar,unbekannt_Unsicher
I tried replacing the class names with syntactically valid ones using the make.names function, but it didn't help, and I don't know what else to change. Do you have any idea what causes the error and how to fix it?
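For reference, here is a minimal sketch of what such a renaming with make.names could look like. It assumes the renaming is applied to the levels of the target factor before the task is created. The toy vector below is a hypothetical stand-in for training_data$kategorie_who, using a few level strings copied from the error message.

```r
# Hypothetical stand-in for the real target column (training_data$kategorie_who);
# the level strings are taken from the error message above.
kategorie_who <- factor(c("1.1_2", "12_Nicht anwendbar",
                          "unbekannt_Unsicher", "1.1_2"))

# make.names() converts each level into a syntactically valid R name;
# unique = TRUE prevents two different levels from collapsing into one name.
levels(kategorie_who) <- make.names(levels(kategorie_who), unique = TRUE)

# Inspect the renamed levels and confirm that no level has zero observations.
print(levels(kategorie_who))
print(table(kategorie_who))
```

With the real data, the same assignment would be run on training_data$kategorie_who before calling makeClassifTask.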