Specify options for final model only with caret


Context

I am using caret to fit and tune models. Typically, the best parameters are found using a resampling method such as cross-validation. Once the best parameters are chosen, a final model is fitted to the whole training data using the best set of parameters.

In addition to the parameters to tune (passed via tuneGrid), one can pass arguments to the underlying algorithm being called by passing them to train.

My question

Is there any way to specify model-specific options to be used for the final model only?

For extra clarity: I do want to fit all the intermediate models (to obtain a reliable performance estimate) but I want to fit the final model with different arguments (in addition to the best parameters).

Specific use case

Let's say I want to fit a bartMachine to some data and then use the final model in production. I would typically save the tuned model to disk and load it as needed. But a bartMachine model can only be saved and reloaded if it has been serialized, i.e. I need to pass serialize=TRUE to bartMachine via caret::train.

But that serializes every intermediate model as well, which is very impractical. I really only need to serialize the final model. Is there any way to do that?

library("caret")
library("bartMachine")
tgrid <- expand.grid(num_trees = 100,
                     k = c(2, 3),
                     alpha = 0.95, 
                     beta = 2,
                     nu =  3)
# The printed log shows that all intermediate models are being serialized
fit <- train(hp ~ ., 
             data=mtcars, 
             method="bartMachine",
             serialize=TRUE,
             tuneGrid=tgrid,
             trControl = trainControl(method="cv", number=5, verboseIter=TRUE))

There is 1 best solution below


To fit a model to the entire data set without parameter tuning or resampling, set the trainControl method to "none":

tgrid <- expand.grid(num_trees = 100,
                     k = 2,
                     alpha = 0.95, 
                     beta = 2,
                     nu =  3)
fit <- train(hp ~ ., 
             data=mtcars, 
             method="bartMachine",
             serialize=TRUE,
             tuneGrid=tgrid,
             trControl = trainControl(method="none"))

Note that I have removed one of the two k values from the question code. Otherwise train stops with the error "Only one model should be specified in tuneGrid with no resampling". I suggest building a separate model with the other k value.
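To address the original question more directly, the two approaches can be combined: tune with cross-validation without serializing, then refit once with the winning parameters and serialize=TRUE. The following is only a sketch, assuming that the bestTune element of the tuned object (a one-row data frame of the selected parameters) can be passed straight back as tuneGrid. The names cv_fit and final_fit are made up for illustration.

```r
library("caret")
library("bartMachine")

# Step 1: tune with 5-fold cross-validation.
# serialize is NOT passed here, so no intermediate model is serialized.
tgrid <- expand.grid(num_trees = 100,
                     k = c(2, 3),
                     alpha = 0.95,
                     beta = 2,
                     nu = 3)
cv_fit <- train(hp ~ .,
                data = mtcars,
                method = "bartMachine",
                tuneGrid = tgrid,
                trControl = trainControl(method = "cv", number = 5))

# Step 2: refit a single model on the full training data with the
# selected parameters, this time asking bartMachine to serialize it.
final_fit <- train(hp ~ .,
                   data = mtcars,
                   method = "bartMachine",
                   serialize = TRUE,
                   tuneGrid = cv_fit$bestTune,
                   trControl = trainControl(method = "none"))
```

One trade-off of this sketch: the first call to train still fits its own (unserialized) final model, so the full-data fit is done twice. Since BART is fitted by random sampling, the refitted model will also not be identical to cv_fit$finalModel.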

The code above gives the following output:

bartMachine initializing with 100 trees...
bartMachine vars checked...
bartMachine java init...
bartMachine factors created...
bartMachine before preprocess...
bartMachine after preprocess... 11 total features...
bartMachine sigsq estimated...
bartMachine training data finalized...
Now building bartMachine for regression ...
building BART with mem-cache speedup...
Iteration 100/1250  mem: 17.6/477.1MB
Iteration 200/1250  mem: 25.1/477.1MB
Iteration 300/1250  mem: 30.8/477.1MB
Iteration 400/1250  mem: 39.9/477.1MB
Iteration 500/1250  mem: 19/477.1MB
Iteration 600/1250  mem: 59.6/477.1MB
Iteration 700/1250  mem: 39.6/477.1MB
Iteration 800/1250  mem: 79.8/477.1MB
Iteration 900/1250  mem: 119.9/477.1MB
Iteration 1000/1250  mem: 40.7/477.1MB
Iteration 1100/1250  mem: 80.8/477.1MB
Iteration 1200/1250  mem: 121/477.1MB
done building BART in 1.289 sec 

burning and aggregating chains from all threads... done
evaluating in sample data...done
serializing in order to be saved for future R sessions...done

The serialize parameter is set to TRUE in fit$finalModel:

fit$finalModel$serialize
[1] TRUE

For what it's worth, the bartMachine internal check_serialization function does not give any warnings or errors (or any other output):

bartMachine:::check_serialization(fit$finalModel)

It's not clear to me how to extract the serialized object from fit$finalModel. I presume it is stored in fit$finalModel$java_bart_machine which contains an rJava pointer. It may be possible to gain further insight using the rJava package which bartMachine depends on.

Update: @antoine-sac states in the comments below "serialize=T does not cause the model to be saved but serialises the samples into the model, which means they are saved when the model is written to disk".
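Going by that comment, writing the serialized fit to disk with base R's saveRDS and reading it back with readRDS should be enough. This is a minimal sketch under that assumption, not something verified across separate R sessions. The file path and the names model_path, restored and preds are illustrative only.

```r
library("caret")
library("bartMachine")

# Fit a single serialized model, as in the answer above
tgrid <- expand.grid(num_trees = 100,
                     k = 2,
                     alpha = 0.95,
                     beta = 2,
                     nu = 3)
fit <- train(hp ~ .,
             data = mtcars,
             method = "bartMachine",
             serialize = TRUE,
             tuneGrid = tgrid,
             trControl = trainControl(method = "none"))

# Write the whole train object to disk (temporary file used for illustration)
model_path <- tempfile(fileext = ".rds")
saveRDS(fit, model_path)

# Read it back and predict with the restored object
restored <- readRDS(model_path)
preds <- predict(restored, newdata = mtcars)
```

In a real deployment the readRDS and predict calls would run in a fresh R session, with caret and bartMachine loaded first.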