How to compute two-way partial dependence (i.e., how to marginalize a model over multiple variables)?

I'd like to understand the computation of a 2-way partial dependence. In this case I'm working with a random forest. Toy model:

library(ggplot2) # access the mpg data set
mod <- ranger::ranger(formula=year ~ cyl + cty + displ,
                      data=mpg, probability = TRUE)

Background: Computing partial dependence for a single variable:

To compute the partial dependence for one variable, say cty, I would define a grid of values over its support:

grid_size <- 10
cty_vals <- seq(min(mpg$cty), max(mpg$cty), length.out = grid_size)

For each of those values, predict on a copy of the dataset in which every observation has that value for cty.

pd_dat <- do.call(rbind, lapply(cty_vals, function(cty_val){
   d <- mpg
   d$cty <- cty_val
   return(d)
 })
)

predict(mod, pd_dat)$predictions

To get the final PD I would take the average prediction for each cty_val (not shown).
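The averaging step that is "not shown" could look roughly like the sketch below. It is self-contained rather than a continuation of the code above: `dat`, `mod_f` and `pd_cty` are names introduced here for illustration, and converting `year` to a factor is an assumption made so that the probability forest is unambiguous.

```r
library(ranger)

set.seed(1)
dat <- as.data.frame(ggplot2::mpg)
dat$year <- factor(dat$year)  # assumption: model year as a two-class factor

mod_f <- ranger(year ~ cyl + cty + displ, data = dat, probability = TRUE)

grid_size <- 10
cty_vals <- seq(min(dat$cty), max(dat$cty), length.out = grid_size)

# One row per grid value: fix cty for every observation, predict, then average
pd_cty <- do.call(rbind, lapply(cty_vals, function(cty_val) {
  d <- dat
  d$cty <- cty_val
  probs <- predict(mod_f, data = d)$predictions  # one row per observation
  data.frame(cty = cty_val, t(colMeans(probs)))  # mean probability per class
}))

pd_cty
```

Averaging inside the loop avoids building one large stacked data frame; the stacked version followed by a grouped mean should give the same numbers.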

Objective: Compute multi-way partial dependence: which method?

Let's say I now want to compute the two-way partial dependence, $PD(cty, displ)$. Define the grid for each variable:

cty_vals <- seq(min(mpg$cty), max(mpg$cty), length.out = grid_size)
displ_vals <- seq(min(mpg$displ), max(mpg$displ), length.out = grid_size)
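For reference, the quantity I am trying to estimate, written in the usual form of an empirical partial dependence, is sketched below. The symbols are my own notation: $\hat{f}$ is the fitted model's predicted probability for a given class, $n$ is the number of rows in mpg, and $cyl_i$ is the observed value of the remaining predictor in row $i$.

```latex
\widehat{PD}(c, d) \;=\; \frac{1}{n} \sum_{i=1}^{n} \hat{f}\left(cty = c,\; displ = d,\; cyl = cyl_i\right)
```

In words: for a fixed pair $(c, d)$, overwrite cty and displ in every row, predict, and average over the rows. The open question is which pairs $(c, d)$ to evaluate this at.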

From here I see two options.

  1. Predict to a dataset that contains the exhaustive combinations of cty and displ

Setup option 1:

combos <- expand.grid(cty_vals, displ_vals)
names(combos) <- c('cty', 'displ')

  2. Predict to a dataset that contains arbitrary (here, element-wise paired) combinations of cty and displ

Setup option 2:

combos <- data.frame(cty_vals, displ_vals)
names(combos) <- c('cty', 'displ')
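
To make the difference between the two setups above concrete, here is a small base-R sketch. The numeric ranges are illustrative placeholders (an assumption, not read from mpg), and `combos_full`, `combos_paired` and `n_shared` are names introduced only for this comparison.

```r
grid_size <- 10
# Illustrative ranges only (assumption); in the question these come from mpg
cty_vals   <- seq(9, 35, length.out = grid_size)
displ_vals <- seq(1.6, 7, length.out = grid_size)

# Option 1: every cty value crossed with every displ value
combos_full <- expand.grid(cty = cty_vals, displ = displ_vals)

# Option 2: the i-th cty value paired only with the i-th displ value
combos_paired <- data.frame(cty = cty_vals, displ = displ_vals)

nrow(combos_full)    # grid_size^2 rows
nrow(combos_paired)  # grid_size rows

# Count the paired rows that also occur in the full grid
n_shared <- nrow(merge(combos_full, combos_paired))
```

So option 2 evaluates only a subset of the points that option 1 evaluates: the pairs where both variables sit at the same grid position.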

The resulting data frame on which to compute predictions:


pd_dat <- do.call(rbind, lapply(1:nrow(combos), function(i){
     d <- mpg
     d$cty <- combos$cty[i]
     d$displ <- combos$displ[i]
     return(d)
}))

In either case I suppose the final result will summarize the predictions over each combination of cty and displ.

library(dplyr)

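# Label each prediction with the cty/displ values of its own row in pd_dat.
# (combos has one row per combination, while pd_dat has nrow(mpg) rows per combination.)
preds <- cbind(pd_dat[, c('cty', 'displ')], predict(mod, pd_dat)$predictions)
names(preds)[3:4] <- sprintf('prob_%s', mod$forest$class.values)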

preds %>%
  group_by(cty, displ) %>%
  summarize(partial_2way_1999 = mean(prob_1999),
            partial_2way_2008 = mean(prob_2008))

My intuition is toward option 1 (predict to exhaustive combinations of both variables).
