Modify linetype of waterfall plot (shapviz package)


I am trying to modify the border line of the bars drawn by sv_waterfall() from the shapviz package, without success. I want to customize my SHAP waterfall plot by removing the bar outline (by supplying linetype = none), but it fails miserably.

I found on rdrr.io that the package developer uses gggenes::geom_gene_arrow() under the hood to draw the waterfall bars (please correct me if I am wrong). This kind of customization seems to be beyond my understanding.
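The guess about gggenes::geom_gene_arrow() can be checked directly, because sv_waterfall() returns an ordinary ggplot object whose layers can be inspected. The sketch below lists the geom class of each layer; it uses a tiny made-up SHAP matrix S and data frame X (illustration data only, not from the question) so that no model has to be fitted, and assumes a ggplot2 version where p$layers is accessible.

```r
library(shapviz)
library(ggplot2)

# Made-up SHAP values for two features, just so the example runs without a model
S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
X <- data.frame(x = c("a", "b"), y = c(100, 10))
sv <- shapviz(S, X = X, baseline = 4)

p <- sv_waterfall(sv, row_id = 1)

# First class of the geom behind each layer of the returned ggplot
layer_geoms <- vapply(p$layers, function(l) class(l$geom)[1], character(1))
print(layer_geoms)
```

If the first entry is GeomGeneArrow, the bars are indeed drawn by gggenes, and their border is controlled by that layer's colour and linetype aesthetics.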

Here is my code:

library(fastshap)
library(shapviz)
library(ggplot2)

# SHAP values for the fitted linear model MLR_N_All
shap_mlr <- fastshap::explain(
  MLR_N_All,
  X = X_df_train,
  nsim = 1000,
  pred_wrapper = predict,
  shap_only = FALSE,
  adjust = TRUE
)

sv_LM <- shapviz(shap_mlr)
sv_LM_wtf <- sv_waterfall(sv_LM, row_id = 3, fill_colors = c("#FF0051", "#008BFB")) +
  ggtitle("LM") +
  theme(
    text = element_text(size = 14, family = "Palatino"),
    axis.title.x = element_blank()
  )
## + gggenes::geom_gene_arrow(aes(linetype = 0))  # fails

sv_LM_wtf

Any idea?


Answer by Michael M (accepted answer)

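Instead of removing the outline, you can set each bar's border colour to its fill colour, so the border no longer shows. Build the plot with ggplot_build(), copy the fill column of the first layer (the bars) into its colour column, and redraw: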

library(shapviz)
library(ggplot2)
library(xgboost)

set.seed(3653)

x <- c("carat", "cut", "color", "clarity")
dtrain <- xgb.DMatrix(data.matrix(diamonds[x]), label = diamonds$price)
fit <- xgb.train(params = list(learning_rate = 0.1), data = dtrain, nrounds = 65)

# Explanation data
dia_small <- diamonds[sample(nrow(diamonds), 2000), ]

shp <- shapviz(fit, X_pred = data.matrix(dia_small[x]), X = dia_small)

p <- sv_waterfall(shp, row_id = 1)
p

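# Build the plot, copy each bar's fill colour into its border colour, then redraw
q <- ggplot_build(p)
q$data[[1]]$colour <- q$data[[1]]$fill  # layer 1 holds the waterfall bars
plot(ggplot_gtable(q))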
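The same trick can be wrapped in a small reusable function. This is a sketch, not part of shapviz: the name hide_waterfall_borders() is hypothetical, and instead of assuming the bars sit in layer 1 it looks for layers whose geom inherits from GeomGeneArrow (the gggenes geom mentioned in the question). It returns a gtable rather than a ggplot, so further ggplot2 additions such as ggtitle() must be made before calling it. S and X are made-up illustration data.

```r
library(shapviz)
library(ggplot2)
library(grid)

# Hypothetical helper: paint each waterfall bar's border in its own fill colour.
# Returns a gtable (already built), not a ggplot.
hide_waterfall_borders <- function(p) {
  built <- ggplot_build(p)
  # Locate the bar layer(s) by geom class instead of hard-coding layer 1
  is_arrow <- vapply(p$layers, function(l) inherits(l$geom, "GeomGeneArrow"), logical(1))
  for (i in which(is_arrow)) {
    built$data[[i]]$colour <- built$data[[i]]$fill
  }
  ggplot_gtable(built)
}

# Made-up SHAP values for two features
S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
X <- data.frame(x = c("a", "b"), y = c(100, 10))
sv <- shapviz(S, X = X, baseline = 4)

p <- sv_waterfall(sv, row_id = 1) + ggtitle("Example")
g <- hide_waterfall_borders(p)

grid.newpage()
grid.draw(g)
```

Applied to the question's plot, this would be hide_waterfall_borders(sv_LM_wtf), since the title and theme are already added there.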


Answer by jared_mamrot

One possible solution is to set the linetype mapping to NA, shown here with the example from https://github.com/ModelOriented/shapviz:

# install.packages("shapviz")
library(shapviz)
library(ggplot2)
library(xgboost)

set.seed(3653)

x <- c("carat", "cut", "color", "clarity")
dtrain <- xgb.DMatrix(data.matrix(diamonds[x]), label = diamonds$price)
fit <- xgb.train(params = list(learning_rate = 0.1), data = dtrain, nrounds = 65)

# Explanation data
dia_small <- diamonds[sample(nrow(diamonds), 2000), ]

shp <- shapviz(fit, X_pred = data.matrix(dia_small[x]), X = dia_small)

p <- sv_waterfall(shp, row_id = 1)
p

# set linetype to NA
p$mapping$linetype <- NA
p
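A variant that combines both answers: rather than changing the mapping of the ggplot object, one can set the linetype column of the built bar layer to "blank" (grid's line type for drawing no line) and redraw. This is a sketch that assumes, like the accepted answer, that layer 1 holds the bars and that geom_gene_arrow() passes the linetype column on to the polygons it draws; S and X are made-up illustration data.

```r
library(shapviz)
library(ggplot2)
library(grid)

# Made-up SHAP values for two features
S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
X <- data.frame(x = c("a", "b"), y = c(100, 10))
sv <- shapviz(S, X = X, baseline = 4)

p <- sv_waterfall(sv, row_id = 1)

# Build once, then switch off the outline of every bar in layer 1
q <- ggplot_build(p)
q$data[[1]]$linetype <- "blank"
g <- ggplot_gtable(q)

grid.newpage()
grid.draw(g)
```

Unlike changing p$mapping, this leaves p untouched, but the result is a gtable, so any further ggplot2 layers or themes have to be added to p before building.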

Created on 2023-12-06 with reprex v2.0.2

Is this your desired outcome? Or have I misunderstood?