I'd like to understand the computation of a 2-way partial dependence. In this case I'm working with a random forest. Toy model:
library(ggplot2) # access the mpg data set
mod <- ranger::ranger(formula=year ~ cyl + cty + displ,
data=mpg, probability = TRUE)
Background: Computing partial dependence for a single variable:
To compute the partial dependence for one variable, say cty, I would define a grid of values over its support:
grid_size <- 10
cty_vals <- seq(min(mpg$cty), max(mpg$cty), length.out = grid_size)
For each of those values, predict to a dataset where all observations have that value for cty.
pd_dat <- do.call(rbind, lapply(cty_vals, function(cty_val){
d <- mpg
d$cty <- cty_val
return(d)
}))
predict(mod, pd_dat)$predictions
To get the final PD I would take the average prediction for each cty_val (not shown).
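Here is a minimal base-R sketch of the averaging step that is left out above. It assumes a generic prediction function `pred_fun(data)` that returns one numeric value per row. This is a hypothetical stand-in for, for example, one class column of `predict(mod, d)$predictions`. The toy data and predictor are made up so the result can be checked by hand.

```r
# One-way partial dependence: for each grid value, overwrite `var` in every
# row, predict, and average the predictions over all rows.
pd_1way <- function(pred_fun, data, var, grid) {
  vapply(grid, function(v) {
    d <- data
    d[[var]] <- v              # every observation gets the same value of var
    mean(pred_fun(d))          # average prediction = PD at this grid value
  }, numeric(1))
}

# Hypothetical toy data and predictor (not the ranger model)
toy <- data.frame(x1 = c(1, 2, 3, 4), x2 = c(0, 1, 0, 1))
toy_pred <- function(d) 2 * d$x1 + d$x2
grid <- seq(min(toy$x1), max(toy$x1), length.out = 4)

pd_vals <- pd_1way(toy_pred, toy, "x1", grid)
pd_vals  # 2 * grid + mean(toy$x2), i.e. 2.5 4.5 6.5 8.5
```

This version loops over the grid instead of stacking one large data frame. The averages are the same; the trade-off is memory versus the number of `predict` calls. With the forest, one would pass something like `function(d) predict(mod, d)$predictions[, 2]` (the second class column).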
Objective: Compute multi-way partial dependence: which method?
Let's say I now want to compute the two-way partial dependence, $PD(cty, displ)$. Define the grid for each variable:
cty_vals <- seq(min(mpg$cty), max(mpg$cty), length.out = grid_size)
displ_vals <- seq(min(mpg$displ), max(mpg$displ), length.out = grid_size)
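For reference, here is the standard estimator of a two-way partial dependence. In it, $\hat{f}$ is the fitted model (here, the predicted probability of one class), $\mathbf{x}_{i,C}$ holds observation $i$'s values of the remaining predictors (cyl in this model), and $n$ is the number of rows in mpg:

$$
\widehat{PD}(z_1, z_2) = \frac{1}{n} \sum_{i=1}^{n} \hat{f}\left(z_1, z_2, \mathbf{x}_{i,C}\right)
$$

The average runs over the $n$ observations with both $z_1$ (a cty value) and $z_2$ (a displ value) held fixed. As a result, $\widehat{PD}$ is a function of the pair $(z_1, z_2)$.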
From here I see two options.
- Predict to a dataset that contains the exhaustive combinations of cty and displ (every cty value paired with every displ value).
Setup option 1:
combos <- expand.grid(cty_vals, displ_vals)
names(combos) <- c('cty', 'displ')
- Predict to a dataset that contains arbitrary combinations of cty and displ (in the setup below, the i-th cty value paired with the i-th displ value).
Setup option 2:
combos <- data.frame(cty_vals, displ_vals)
names(combos) <- c('cty', 'displ')
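The following base-R sketch shows how the two setups differ in size and coverage. It uses made-up three-point grids as hypothetical stand-ins for cty_vals and displ_vals:

```r
a_vals <- c(1, 2, 3)          # hypothetical stand-in for cty_vals
b_vals <- c(10, 20, 30)       # hypothetical stand-in for displ_vals

opt1 <- expand.grid(cty = a_vals, displ = b_vals)  # every (cty, displ) pair
opt2 <- data.frame(cty = a_vals, displ = b_vals)   # i-th value with i-th value

nrow(opt1)  # 9: length(a_vals) * length(b_vals)
nrow(opt2)  # 3: only the pairs (1, 10), (2, 20), (3, 30)

# Every option-2 pair is also an option-1 pair
nrow(merge(opt2, opt1))  # 3
```

With `grid_size = 10`, option 1 gives 100 grid points (100 * nrow(mpg) prediction rows) and option 2 gives only 10. Note that `expand.grid` varies its first argument fastest.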
The resulting data frame on which to compute predictions:
pd_dat <- do.call(rbind, lapply(1:nrow(combos), function(i){
d <- mpg
d$cty <- combos$cty[i]
d$displ <- combos$displ[i]
return(d)
}))
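The stacking step can also be written without `lapply`/`rbind` by repeating row indices. This sketch uses hypothetical toy data in place of mpg, and it makes the row order explicit. All observations for combination 1 come first, then all observations for combination 2, and so on. Any step that attaches the cty/displ labels to the predictions has to respect that ordering.

```r
# Hypothetical toy data standing in for mpg, and a small grid of combinations
dat <- data.frame(cty = c(15, 20), displ = c(2, 3), cyl = c(4, 6))
combos <- expand.grid(cty = c(10, 30), displ = c(1, 5))
n <- nrow(dat)

# Repeat the whole data set once per combination (block i = combination i)
pd_dat_vec <- dat[rep(seq_len(n), times = nrow(combos)), ]
# Within each block, overwrite cty/displ with that combination's values
pd_dat_vec$cty   <- rep(combos$cty,   each = n)
pd_dat_vec$displ <- rep(combos$displ, each = n)
rownames(pd_dat_vec) <- NULL

# The lapply/rbind construction from the question, for comparison
pd_dat_loop <- do.call(rbind, lapply(seq_len(nrow(combos)), function(i) {
  d <- dat
  d$cty <- combos$cty[i]
  d$displ <- combos$displ[i]
  d
}))
rownames(pd_dat_loop) <- NULL

isTRUE(all.equal(pd_dat_vec, pd_dat_loop))  # TRUE
```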
In either case I suppose the final result will summarize the predictions over each combination of cty and displ.
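library(dplyr)
# Take cty/displ from pd_dat itself: pd_dat has nrow(combos) * nrow(mpg) rows,
# so cbind(combos, ...) would recycle combos and attach the wrong labels.
preds <- cbind(pd_dat[, c('cty', 'displ')], predict(mod, pd_dat)$predictions)
names(preds)[3:4] <- sprintf('prob_%s', mod$forest$class.values)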
preds %>%
group_by(cty, displ) %>%
summarize(partial_2way_1999 = mean(prob_1999),
partial_2way_2008 = mean(prob_2008))
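The same two-way computation can be done in base R without dplyr and checked against a toy predictor whose partial dependence is known in closed form. The function `pd_2way`, `toy_pred2`, and the data are hypothetical. With the forest, one would pass something like `function(d) predict(mod, d)$predictions[, 2]` as `pred_fun`. The labels are taken from the stacked rows themselves, so they stay aligned with the predictions.

```r
# Two-way partial dependence over a set of (var1, var2) combinations
pd_2way <- function(pred_fun, data, var1, var2, combos) {
  n <- nrow(data)
  # Stack one full copy of the data per combination (combination i = block i)
  stacked <- data[rep(seq_len(n), times = nrow(combos)), , drop = FALSE]
  rownames(stacked) <- NULL
  stacked[[var1]] <- rep(combos[[var1]], each = n)
  stacked[[var2]] <- rep(combos[[var2]], each = n)
  # Labels come from the stacked rows, so they stay aligned with predictions
  preds <- stacked[c(var1, var2)]
  preds$pred <- pred_fun(stacked)
  # Average over the n observations within each (var1, var2) combination
  aggregate(preds["pred"], by = preds[c(var1, var2)], FUN = mean)
}

# Hypothetical toy data and a predictor with a known partial dependence:
# f = 2*x1 + 3*x2 + x3 + x1*x3, so PD(z1, z2) = 2*z1 + 3*z2 + m + z1*m,
# where m = mean(x3) = 1 here, i.e. PD(z1, z2) = 3*z1 + 3*z2 + 1.
toy <- data.frame(x1 = c(1, 2, 3), x2 = c(5, 6, 7), x3 = c(0, 1, 2))
toy_pred2 <- function(d) 2 * d$x1 + 3 * d$x2 + d$x3 + d$x1 * d$x3
grid2 <- expand.grid(x1 = c(0, 1), x2 = c(0, 2))

res <- pd_2way(toy_pred2, toy, "x1", "x2", grid2)
res
```

The interaction term `x1 * x3` shows why the averaging matters: the partial dependence picks up `z1 * mean(x3)`, not `z1` times any single observation's x3.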
My intuition is toward option 1 (predict to exhaustive combinations of both variables).
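With the same grids, option 2 evaluates the same averaged function only at the pairs (cty_vals[i], displ_vals[i]). Those pairs are the diagonal of the option-1 grid. The self-contained check below uses a hypothetical toy predictor and hypothetical grids in place of the forest and mpg:

```r
# PD at a single (z1, z2) point: fix x1 = z1 and x2 = z2 in every row,
# then average the predictions (hypothetical toy model below)
pd_point <- function(pred_fun, data, z1, z2) {
  d <- data
  d$x1 <- z1
  d$x2 <- z2
  mean(pred_fun(d))
}

toy <- data.frame(x1 = c(1, 2, 3), x2 = c(5, 6, 7), x3 = c(0, 1, 2))
toy_pred <- function(d) d$x1 * d$x3 + d$x2^2   # includes an interaction

g1 <- c(0, 1, 2)   # hypothetical grid for x1
g2 <- c(3, 4, 5)   # hypothetical grid for x2

# Option 1: full grid, stored as a length(g1) x length(g2) matrix
pd_grid <- outer(g1, g2, Vectorize(function(a, b) pd_point(toy_pred, toy, a, b)))

# Option 2: element-wise pairs (g1[i], g2[i])
pd_pairs <- mapply(function(a, b) pd_point(toy_pred, toy, a, b), g1, g2)

isTRUE(all.equal(pd_pairs, diag(pd_grid)))  # TRUE
```

Under these assumptions, option 2 is not a different estimator. It is a sparser evaluation that covers only the diagonal of the surface that option 1 estimates.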