I am trying to build a survxai explainer from a survival model built with mlr3proba. I'm having trouble creating the predict_function necessary for the explainer. Has anyone ever tried to build something like this?
So far, my code is the following:
require(survxai)
require(survival)
require(survivalmodels)
require(mlr3proba)
require(mlr3pipelines)
create_pipeops <- function(learner) {
  GraphLearner$new(po("encode") %>>% po("scale") %>>% po("learner", learner))
}
fit<-lrn("surv.deepsurv")
fit<-create_pipeops(fit)
data<-veteran
survival_task<-TaskSurv$new("veteran", veteran, time = "time", event = "status")
fit$train(survival_task)
predict_function <- function(model, newdata, times = NULL) {
  if (!is.data.frame(newdata)) {
    newdata <- data.frame(newdata)
  }
  surv_task <- TaskSurv$new("task", newdata, time = "time",
                            event = "status")
  pred <- model$predict(surv_task)
  mat <- matrix(pred$data$distr, nrow = nrow(pred$data$distr))
  colnames(mat) <- colnames(pred$data$distr)
  return(mat)
}
explainer <- survxai::explain(model = learner$model, data = veteran[,-c(3,4)],
                              y = Surv(veteran$time, veteran$status),
                              predict_function = predict_function)
pred_breakdown<-prediction_breakdown(explainer, veteran[1,])
It throws the following error:
Error in [.data.table(r6_private(backend)$.data, , event, with = FALSE) : column(s) not found: status
I suspect that once this one is solved there may be more. I don't fully understand the structure of the object that the function returns.
In the predict_function, I included the times argument because, according to the R help page, the function must take all three arguments.
Working example with randomForestSRC here, you can just change surv.rfsrc to surv.deepsurv for your example. BTW we are planning on implementing this within mlr3proba soon, or I might just add it directly to survivalmodels, still deciding!
Created on 2022-06-07 by the reprex package (v2.0.1)
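As for the error in the question itself: it looks like it comes from the explainer being given veteran[,-c(3,4)], so the newdata that reaches predict_function has no time or status columns, while TaskSurv$new() asks for both. One possible workaround, sketched below in base R only, is to add placeholder outcome columns before building the task. The helper name add_placeholder_outcome and the placeholder values are assumptions made for illustration; they are not part of survxai or mlr3proba, and the values are only there so that the task can be constructed.

```r
# Hypothetical helper (not part of survxai or mlr3proba): make sure the
# outcome columns exist so that a survival task can be built from newdata.
add_placeholder_outcome <- function(newdata, time = "time", event = "status") {
  newdata <- as.data.frame(newdata)
  # Only add a column when it is missing, so real outcomes are left untouched.
  if (!time %in% names(newdata)) {
    newdata[[time]] <- rep(1, nrow(newdata))    # arbitrary positive time
  }
  if (!event %in% names(newdata)) {
    newdata[[event]] <- rep(0L, nrow(newdata))  # arbitrary: all censored
  }
  newdata
}

# Toy data frame that mimics veteran[, -c(3, 4)] (no outcome columns)
toy <- data.frame(trt = c(1, 2), karno = c(60, 70))
str(add_placeholder_outcome(toy))
```

Inside predict_function, this would be called on newdata just before TaskSurv$new(...). Whether the matrix returned afterwards has the shape survxai expects is a separate issue and is not covered by this sketch.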