How to terminate a currently long running XGBoost CV process in R shiny using a Button?

150 Views Asked by At

I would like to implement a cross-validation model in R Shiny using the xgboost model and the xgb.cv() function.

Since this process/function can take a couple of hours to complete, I would like to add a "Cancel" button, backed by a stop-process function, so that the user can terminate the process at any time.

Could you please advise me on how to proceed?

Server Code:

# Shiny server function: when ML_Submit_Button is clicked, swaps the Submit
# button for the Stop button, runs the XGBoost grid-search cross-validation,
# and swaps the buttons back afterwards.
#
# `values` is not defined in this block; presumably it is a reactiveValues
# object created elsewhere that holds the training data and slider choices.
server <- function(input, output, session) {

  observeEvent(input$ML_Submit_Button, {
    shinyjs::hide("ML_Submit_Button")
    shinyjs::show("ML_Stop_Button")

    # NOTE(review): this call runs synchronously inside the observer, so a
    # click on ML_Stop_Button cannot be handled until it returns. Real
    # cancellation needs the work to run outside this R process.
    tryCatch(
      xgb_gs_cv_regression(
        xgb_train = values$xgb_train,
        subsample_choice = values$subsample_slider_seq,
        colsample_bytree_choice = values$colsample_bytree_slider_seq,
        max_depth_choice = values$max_depth_slider_seq,
        min_child_weight_choice = values$min_child_weight_slider_seq,
        eta_choice = values$eta_slider_seq,
        n_rounds_choice = values$n_rounds_slider_seq,
        n_fold_choice = values$n_fold_slider_seq
      ),
      # Restore the buttons even when the grid search fails; otherwise the
      # Submit button would stay hidden after an error.
      finally = {
        shinyjs::hide("ML_Stop_Button")
        shinyjs::show("ML_Submit_Button")
      }
    )
  })

}

XGB CV Function Code:

# Grid-search cross-validation for an XGBoost regression model.
#
# Builds the full Cartesian grid of the supplied hyperparameter choices, runs
# xgb.cv() once per combination, and collects the cross-validated RMSE.
#
# Arguments:
#   xgb_train               Training data, passed unchanged as `data` to
#                           xgb.cv() (presumably an xgb.DMatrix).
#   subsample_choice        Candidate values for `subsample`.
#   colsample_bytree_choice Candidate values for `colsample_bytree`.
#   max_depth_choice        Candidate values for `max_depth`.
#   min_child_weight_choice Candidate values for `min_child_weight`.
#   eta_choice              Candidate values for `eta`.
#   n_rounds_choice         Candidate values for `nrounds`.
#   n_fold_choice           Candidate values for `nfold`.
#
# Returns:
#   A data frame with one row per combination and the columns TestRMSE,
#   TrainRMSE, SubSampRate, ColSampRate, Depth, eta, currentMinChildWeight,
#   nrounds and nfold. If any choice vector is empty the grid has no rows and
#   a zero-row data frame with the same columns is returned.
xgb_gs_cv_regression <- function(xgb_train,
                                 subsample_choice,
                                 colsample_bytree_choice,
                                 max_depth_choice,
                                 min_child_weight_choice,
                                 eta_choice,
                                 n_rounds_choice,
                                 n_fold_choice) {
  search_grid <- expand.grid(
    subsample = subsample_choice,
    colsample_bytree = colsample_bytree_choice,
    max_depth = max_depth_choice,
    min_child_weight = min_child_weight_choice,
    eta = eta_choice,
    n_rounds = n_rounds_choice,
    n_fold = n_fold_choice
  )

  result_names <- c(
    "TestRMSE",
    "TrainRMSE",
    "SubSampRate",
    "ColSampRate",
    "Depth",
    "eta",
    "currentMinChildWeight",
    "nrounds",
    "nfold"
  )

  # Runs the cross-validation for grid row `i` and returns its result vector
  # in the order of `result_names`.
  evaluate_combination <- function(i) {
    subsample_rate <- search_grid[["subsample"]][i]
    colsample_rate <- search_grid[["colsample_bytree"]][i]
    depth <- search_grid[["max_depth"]][i]
    eta <- search_grid[["eta"]][i]
    min_child_weight <- search_grid[["min_child_weight"]][i]
    n_rounds <- search_grid[["n_rounds"]][i]
    n_fold <- search_grid[["n_fold"]][i]

    cv_fit <- xgb.cv(
      # Booster settings go through `params` rather than `...`.
      params = list(
        objective = "reg:squarederror",
        booster = "gbtree",
        eval_metric = "rmse",
        max_depth = depth,
        eta = eta,
        subsample = subsample_rate,
        colsample_bytree = colsample_rate,
        min_child_weight = min_child_weight
      ),
      data = xgb_train,
      nrounds = n_rounds,
      nfold = n_fold,
      showsd = TRUE,
      verbose = TRUE,
      print_every_n = 10,
      early_stopping_rounds = 10
    )

    eval_log <- cv_fit$evaluation_log
    # With early stopping the log continues past the best round, so report the
    # round with the lowest test RMSE instead of the last logged round.
    best_round <- which.min(eval_log$test_rmse_mean)

    c(
      eval_log$test_rmse_mean[best_round],
      eval_log$train_rmse_mean[best_round],
      subsample_rate,
      colsample_rate,
      depth,
      eta,
      min_child_weight,
      n_rounds,
      n_fold
    )
  }

  # Iterate over row indices instead of apply(), which would coerce the grid
  # to a matrix; seq_len() also keeps an empty grid from running anything.
  per_row <- lapply(seq_len(nrow(search_grid)), evaluate_combination)

  results <- as.data.frame(
    matrix(
      as.numeric(unlist(per_row)),
      nrow = length(per_row),
      ncol = length(result_names),
      byrow = TRUE
    )
  )
  names(results) <- result_names

  results
}
1

There is 1 best solution below

2
ismirsehregal On BEST ANSWER

You can realize the desired pattern via callr::r_bg(), which is based on processx::process().

r_bg() runs R functions in a background R process - which can be canceled via its kill() method.

Actually it doesn't matter which function you are running - so I simplified the example.

Please check the following:

library(shiny)
library(callr)
library(shinyjs)

# Stand-in for an expensive computation: pauses for `x` seconds, then returns
# a short message reporting how long it slept.
long_running_function <- function(x) {
  Sys.sleep(x)
  slept_message <- sprintf("I slept %s seconds", x)
  slept_message
}

# Page layout: enables shinyjs and offers one button to start the background
# process and one to cancel it.
ui <- fluidPage(
  useShinyjs(),
  actionButton(inputId = "runbgp", label = "Run bg process"),
  actionButton(inputId = "cancelbgp", label = "Cancel bg process")
)

# Shiny server function: starts `long_running_function` in a background R
# process via callr::r_bg(), lets the user kill it, and polls once per second
# to pick up the result when the process finishes.
server <- function(input, output, session) {
  # Handle of the currently running background process, or NULL when idle.
  rv <- reactiveValues(bg_process = NULL)

  # Nothing is running at start-up, so there is nothing to cancel yet.
  disable("cancelbgp")

  observeEvent(input$runbgp, {
    disable("runbgp")
    enable("cancelbgp")
    rv$bg_process <- r_bg(
      long_running_function,
      args = list(5),
      stdout = "|",
      stderr = "2>&1"
    )
  })

  observeEvent(input$cancelbgp, {
    # Ignore the click when no process exists; calling get_pid() on NULL
    # would raise an error.
    req(rv$bg_process)
    enable("runbgp")
    disable("cancelbgp")
    cat(paste("Killing process - PID:", rv$bg_process$get_pid(), "\n"))
    rv$bg_process$kill()
    # Drop the handle so the polling observer does not ask a killed process
    # for its result.
    rv$bg_process <- NULL
  })

  observe({
    invalidateLater(1000)
    req(rv$bg_process)
    if (rv$bg_process$poll_io(0)[["process"]] == "ready") {
      enable("runbgp")
      disable("cancelbgp")
      finished_process <- rv$bg_process
      rv$bg_process <- NULL
      # get_result() signals an error if the background function failed;
      # report it instead of letting the observer fail.
      tryCatch(
        print(finished_process$get_result()),
        error = function(e) {
          message("Background process failed: ", conditionMessage(e))
        }
      )
    }
  })
}

# Build the app from the UI definition and server function above.
shinyApp(ui = ui, server = server)