Tuning an SVM with HPO in mlr3: poor predictions


I'm trying to build an SVM learner to predict the target of my task. Here is my data (dput output):

data <- structure(list(
  PatientID = c("P1", "P1", "P1", "P1", "P1", "P1", "P2", "P2",
                "P3", "P4", "P5", "P5", "P5", "P5", "P5", "P6", "P6", "P6"),
  LesionResponse = structure(c(2L, 1L, 2L, 2L, 1L, 2L, 2L, 2L, 2L,
                               2L, 2L, 1L, 1L, 2L, 2L, 2L, 2L, 1L),
                             .Label = c("0", "1"), class = "factor"),
  F1 = c(1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 0.625, 0.625, 0.625,
         0.625, 0.625, 0.625, 1.25, 0.625, 0.625, 1.25, 1.25, 1.25),
  F2 = c(1, 5, 3, 2, 1, 1, 6, 9, 0, 5, 0, 4, 4, 4, 5, 2, 1, 1),
  F3 = c(0, 4, 3, 1, 1, 0, 3, 8, 4, 5, 0, 4, 4, 3, 5, 2, 0, 0),
  F4 = c(0, 9, 0, 7, 4, 0, 3, 8, 4, 5, 9, 1, 1, 3, 5, 3, 9, 0)),
  row.names = c(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L,
                13L, 14L, 15L, 16L, 17L, 18L), class = "data.frame")

Here is the code I used to split my data, normalize it, and create the tasks and my learner. The code may look "artisanal"/not professional; that's expected (all my work is based on the mlr3 book, I'm self-taught and learning to code on data given by my hospital, as I'm a medical student):

library(mlr3)
library(mlr3learners)
library(mlr3tuning)
library(mlr3viz)
library(data.table)

data$PatientID <- as.factor(data$PatientID)
task <- TaskClassif$new(id = "data", backend = data, target = "LesionResponse")

set.seed(1234)

split_data <- function(task, patient_col = "PatientID", target_col = "LesionResponse", train_ratio = 0.6, valid_ratio = 0.2) {
  repeat {
    # Group data by PatientID
    patients <- task$data()[[patient_col]]
    patient_groups <- split(task$data(), patients)

    # Shuffle patient groups
    patient_groups_shuffled <- sample(patient_groups)

    # Calculate the number of samples for train and validation sets
    n_train <- round(length(patient_groups_shuffled) * train_ratio)
    n_valid <- round(length(patient_groups_shuffled) * valid_ratio)

    # Split patient groups into train, validation, and test sets
    train_groups <- patient_groups_shuffled[1:n_train]
    valid_groups <- patient_groups_shuffled[(n_train + 1):(n_train + n_valid)]
    test_groups <- patient_groups_shuffled[(n_train + n_valid + 1):length(patient_groups_shuffled)]

    # Combine patient groups back into data frames
    train_set <- do.call(rbind, train_groups)
    valid_set <- do.call(rbind, valid_groups)
    test_set <- do.call(rbind, test_groups)

    # Check the proportion of "1" in each set
    proportion_train <- mean(train_set[[target_col]] == "1")
    proportion_valid <- mean(valid_set[[target_col]] == "1")
    proportion_test <- mean(test_set[[target_col]] == "1")

    if (proportion_train >= 0.70 && proportion_train <= 0.75 && 
        proportion_valid >= 0.70 && proportion_valid <= 0.75 && 
        proportion_test >= 0.70 && proportion_test <= 0.75) {
      break
    }
  }

  return(list(train = train_set, validation = valid_set, test = test_set))
}

# Call the function to split the data
split_data_sets <- split_data(task)
train_set <- split_data_sets$train
valid_set <- split_data_sets$validation
test_set <- split_data_sets$test
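Note that mlr3 can also express this patient-level grouping directly; a minimal sketch, assuming the "group" column role (it keeps all rows of one patient in the same resampling fold, which is what the hand-rolled split above enforces):

# Sketch: grouped resampling via the "group" column role (assumption: this
# could replace the manual split; ratio then applies to patients, not rows)
task_grouped <- TaskClassif$new(id = "data", backend = data, target = "LesionResponse")
task_grouped$set_col_roles("PatientID", roles = "group")
resampling_grouped <- rsmp("holdout", ratio = 0.8)
resampling_grouped$instantiate(task_grouped)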

# Row indices in `data` belonging to the validation patients
valid_rows <- which(data$PatientID %in% valid_set$PatientID)

# Row indices in `data` belonging to the training patients
train_rows <- which(data$PatientID %in% train_set$PatientID)

# Row indices in `data` belonging to the test patients
test_rows <- which(data$PatientID %in% test_set$PatientID)

# Check for common patients between sets (this must happen BEFORE dropping
# the ID column, otherwise the PatientID columns are already NULL)
common_train_valid <- intersect(train_set$PatientID, valid_set$PatientID)
common_train_test <- intersect(train_set$PatientID, test_set$PatientID)
common_valid_test <- intersect(valid_set$PatientID, test_set$PatientID)

# Drop the patient ID so it is not used as a feature
data$PatientID <- NULL
train_set$PatientID <- NULL
valid_set$PatientID <- NULL
test_set$PatientID <- NULL

# Check the proportion of "1" in each set
train_proportion <- mean(train_set$LesionResponse == "1")
valid_proportion <- mean(valid_set$LesionResponse == "1")
test_proportion <- mean(test_set$LesionResponse == "1")

cat("Proportion of '1' in Train set: ", train_proportion, "\n")
cat("Proportion of '1' in Validation set: ", valid_proportion, "\n")
cat("Proportion of '1' in Test set: ", test_proportion, "\n")

# Print the number of rows in each set
cat("\nNumber of rows in Train set: ", nrow(train_set), "\n")
cat("Number of rows in Validation set: ", nrow(valid_set), "\n")
cat("Number of rows in Test set: ", nrow(test_set), "\n")

# Print the common rows
cat("\nCommon PatientIDs between Train and Validation sets:\n")
if (length(common_train_valid) == 0) {
  cat("Il n'y a aucun patient en commun.\n")
} else {
  cat("Ces IDs sont communs entre les deux groupes:\n")
  print(common_train_valid)
}

cat("\nCommon PatientIDs between Train and Test sets:\n")
if (length(common_train_test) == 0) {
  cat("Il n'y a aucun patient en commun.\n")
} else {
  cat("Ces IDs sont communs entre les deux groupes:\n")
  print(common_train_test)
}

cat("\nCommon PatientIDs between Validation and Test sets:\n")
if (length(common_valid_test) == 0) {
  cat("Il n'y a aucun patient en commun.\n")
} else {
  cat("Ces IDs sont communs entre les deux groupes:\n")
  print(common_valid_test)

# Normalization
normalize_data <- function(train_set, valid_set, test_set) {
  numeric_columns <- names(train_set)[sapply(train_set, is.numeric)]

  # Compute mean and standard deviation on the train set only
  mean_train <- colMeans(train_set[, ..numeric_columns])
  std_train <- apply(train_set[, ..numeric_columns], 2, sd)

  # copy() so the := below does not modify the originals by reference
  train_set_normalized <- copy(train_set)
  valid_set_normalized <- copy(valid_set)
  test_set_normalized <- copy(test_set)

  train_set_normalized[, (numeric_columns) := Map(function(x, m, s) (x - m) / s, .SD, m = mean_train, s = std_train), .SDcols = numeric_columns]
  valid_set_normalized[, (numeric_columns) := Map(function(x, m, s) (x - m) / s, .SD, m = mean_train, s = std_train), .SDcols = numeric_columns]
  test_set_normalized[, (numeric_columns) := Map(function(x, m, s) (x - m) / s, .SD, m = mean_train, s = std_train), .SDcols = numeric_columns]

  return(list(train = train_set_normalized, validation = valid_set_normalized, test = test_set_normalized))
}

# Apply the normalization
normalized_data_sets <- normalize_data(train_set, valid_set, test_set)
train_set <- normalized_data_sets$train
valid_set <- normalized_data_sets$validation
test_set <- normalized_data_sets$test
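A minimal alternative sketch, assuming mlr3pipelines: po("scale") estimates the mean/sd on the training portion of each resampling iteration and applies them to the held-out portion, which avoids the hand-rolled normalization above:

library(mlr3pipelines)

# Sketch: scaling inside a GraphLearner (assumption: this replaces
# normalize_data(); mean/sd are learned on each training split only)
graph_learner <- as_learner(
  po("scale") %>>%
    lrn("classif.svm", type = "C-classification", kernel = "radial",
        predict_type = "prob")
)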

# Create a combined train + validation set for tuning
combined_data <- rbind(train_set, valid_set)
task_tuning <- TaskClassif$new(id = "task_tuning", backend = combined_data, target = "LesionResponse")
train_rows_task <- 1:nrow(train_set)
valid_rows_task <- (nrow(train_set) + 1):(nrow(train_set) + nrow(valid_set))

# Custom resampling on task_tuning: train on the training rows, evaluate on the validation rows
resampling_tuning = rsmp("custom")
resampling_tuning$instantiate(task_tuning, train = list(train_rows_task), test = list(valid_rows_task))
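A quick sanity check on the instantiated custom resampling, just to confirm the row indices line up:

# Should return the training rows and the validation rows, respectively
resampling_tuning$train_set(1)
resampling_tuning$test_set(1)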

# Create the test task; it contains only the test rows
task_test <- TaskClassif$new(id = "task_test", backend = test_set, target = "LesionResponse")

# Tuning the SVM model
learner = lrn("classif.svm",
  cost = to_tune(1e-1, 1e5),
  gamma = to_tune(1e-1, 1),
  type = "C-classification",
  kernel = "radial",
  predict_type = "prob"
)

instance = ti(
  task = task_tuning, 
  learner = learner, 
  resampling = resampling_tuning, 
  measures = msr("classif.ce"), 
  terminator = trm("evals", n_evals = 100)
  )

tuner = tnr("grid_search", resolution = 20, batch_size = 20)

tuner$optimize(instance)

as.data.table(instance$archive)[, list(cost, gamma, classif.ce)]
as.data.table(instance$archive, measures = msrs(c("classif.fpr", "classif.fnr")))[,list(cost, gamma, classif.ce, classif.fpr, classif.fnr)]
autoplot(instance, type = "surface")
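Since cost spans six orders of magnitude, a linear grid of resolution 20 mostly samples large values; a log-scaled search is a common variant (a sketch, assuming logscale = TRUE; the gamma bounds here are only illustrative):

# Sketch: tune on a log scale so small and large values are covered evenly
learner_log = lrn("classif.svm",
  cost = to_tune(1e-1, 1e5, logscale = TRUE),
  gamma = to_tune(1e-4, 1, logscale = TRUE),  # illustrative bounds
  type = "C-classification", kernel = "radial", predict_type = "prob")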

svm_tuned = lrn("classif.svm", id = "SVM Tuned")
svm_tuned$param_set$values = instance$result_learner_param_vals
svm_tuned$train(task_tuning)
predictions = svm_tuned$predict(task_test)
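The scores below can be reproduced along these lines (a sketch, assuming these three measures; note that classif.auc and classif.bbrier require predict_type = "prob" on svm_tuned, otherwise they come back as NaN):

# Score the test-set predictions
predictions$score(msrs(c("classif.auc", "classif.bbrier", "classif.bacc")))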

# Predictions:
   classif.auc classif.bbrier   classif.bacc 
           NaN            NaN            0.5 

As you can see, the SVM performs no better than a random predictor... How could I tune it better? Or is the problem in my parameter settings?
