MLR package hyperparameter tuning of random forest with 'predict.type = prob'


I want to build a random forest model in R, and to find good hyperparameters I want to use the mlr package for automated hyperparameter tuning. It's a classification problem, so I want the model to predict the probabilities of the classes of the outcome variable "kategorie_who". To create the learner I use the makeLearner function with 'classif.randomForest'.

Here's the code:

library(mlr)
library(ParamHelpers)
library(randomForest)

set.seed(123)

#create a task
trainTask <- makeClassifTask(data = training_data, target = "kategorie_who")

#create a learner
rf <- makeLearner("classif.randomForest", 
                  predict.type = "prob", 
                  par.vals = list(ntree = 100,
                                  mtry = floor((ncol(training_data)-1)/3)))
#add importance = TRUE without discarding the ntree/mtry values set above
rf <- setHyperPars(rf, importance = TRUE)

#set tune parameters
#search space for the hyperparameters
rf_param <- makeParamSet(
  makeIntegerParam("ntree", lower = 50, upper = 500), 
  makeIntegerParam("mtry", lower = floor((ncol(training_data)-1)/3),
                   upper = ncol(training_data)-1)) #at most the number of predictors

#random search for 10 iterations
rancontrol <- makeTuneControlRandom(maxit = 10L)

#set 3 fold cross validation
set_cv <- makeResampleDesc("CV", iters = 3L)

#hyper tuning
rf_tune <- tuneParams(learner = rf,
                      resampling = set_cv,
                      task = trainTask,
                      par.set = rf_param, 
                      control = rancontrol,
                      measures = acc)

Now here's the problem: when I set 'predict.type = "response"', everything works. But since I want the class probabilities as output, I set 'predict.type = "prob"', and when I run the code I get the following error:

Error in checkPredictLearnerOutput(.learner, .model, p) : 
  predictLearner for classif.randomForest has returned not the class levels as column names: 1.1_2,1.1_3,1.1_4,1.3_4,1.4_1,1.4_2,1.4_3,1.5_1,1.6_1,1.6_3,12_Nicht anwendbar,13_Nicht anwendbar,14_Nicht anwendbar,2_1,6_1,7.1_Unsicher,7.2_Unsicher,9.1_Nicht anwendbar,unbekannt_Unsicher

I tried renaming the classes to more readable names with the make.names function, but it didn't help, and I don't know what else I could change. Do you have any idea what causes the error and how to fix it?
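For reference, here is a minimal base-R sketch of how the levels of the target could be inspected. It is a check under stated assumptions, not a confirmed diagnosis of the error. It assumes the target is a factor; `toy_data` is a hypothetical stand-in for `training_data`, and the variable names are made up for illustration. It looks for levels that are declared but never occur, levels that make.names would change, and classes with very few observations (which may be missing from individual CV folds).

```r
# Hypothetical stand-in for training_data; only the target column matters here.
toy_data <- data.frame(
  x = c(1.2, 3.4, 5.6, 7.8),
  kategorie_who = factor(
    c("1.1_2", "12_Nicht anwendbar", "1.1_2", "2_1"),
    levels = c("1.1_2", "12_Nicht anwendbar", "2_1", "unbekannt_Unsicher")
  )
)

target <- toy_data$kategorie_who

# Levels that are declared on the factor but never occur in the data.
unused_levels <- setdiff(levels(target), unique(as.character(target)))

# Levels that make.names() would alter (spaces, leading digits, ...).
changed_by_make_names <- levels(target)[levels(target) != make.names(levels(target))]

# Observations per class; rare classes may be absent from some CV folds.
class_counts <- table(target)

# Drop unused levels, then rename the remaining ones to unique, syntactic names.
cleaned <- droplevels(target)
levels(cleaned) <- make.names(levels(cleaned), unique = TRUE)

print(unused_levels)
print(changed_by_make_names)
print(class_counts)
print(levels(cleaned))
```

If the cleaned factor is written back to the data, this has to happen before makeClassifTask is called, so that the task is built from the cleaned levels.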
