quantForestError() predictions on raster stack?

42 Views Asked by At

Does anyone know how to compute quantForestError() predictions on a raster stack instead of a matrix or data.frame? For example:

library(ranger)
library(tidyverse)
library(raster)
library(forestError)
data(iris)
glimpse(iris)

# Random forest for Sepal.Length from the numeric predictors (Species is
# excluded by the formula). keep.inbag = TRUE keeps the in-bag information
# that forestError works from.
mod <- ranger(Sepal.Length ~ . - Species, iris, keep.inbag = TRUE, seed = 4)

# Prediction plus lower/upper prediction-interval bounds for iris row 10.
# Argument order: quantForestError(forest, X.train, X.test, Y.train, what = ...)
# Run it once and read all three columns from the same result instead of
# repeating the identical call for each column.
qfe <- quantForestError(mod, iris[, 2:4], iris[10, 2:4], iris[, 1],
                        what = "interval")
qfe$pred
#4.823236
qfe$upper_0.05
#5.091271
qfe$lower_0.05
#4.293341

#generate raster data
# Three 10 x 10 layers, each filled with the first ncell() values of the
# matching iris column. This is what runif(n, min = v, max = v) yields:
# runif() returns `min` unchanged when min == max, so no random numbers are
# drawn and set.seed() has no effect. head() states that directly.
# For genuinely random values use e.g.
# runif(n_cells, min(iris$Sepal.Width), max(iris$Sepal.Width)).
template <- raster(ncol = 10, nrow = 10)
n_cells <- ncell(template)
sw <- setValues(template, head(iris$Sepal.Width, n_cells))
pl <- setValues(template, head(iris$Petal.Length, n_cells))
pw <- setValues(template, head(iris$Petal.Width, n_cells))
all <- stack(sw, pl, pw)
# Layer names must match the predictor names used to fit `mod`.
names(all) <- c("Sepal.Width", "Petal.Length", "Petal.Width")
plot(all)

##try on raster --> error
# Passing the RasterStack itself as X.test fails: as the error below states,
# quantForestError() requires X.test to be a matrix or data.frame. The cell
# values must be handed over as a data.frame (directly, or via a `fun`
# wrapper in a raster/terra predict() call).
quantForestError(mod,iris[,2:4],all,iris[,1],what=c("interval"))$pred
#Error in checkXtrainXtest(X.train, X.test) : 
#'X.test' must be a matrix or data.frame of dimension 2


#for what it's worth this is how I would get prediction on raster normally with ranger
# predict() on the stack calls `fun(model, <cell values>, ...)`, forwarding
# the extra arguments (type, seed). ranger_component() builds such a `fun`
# that runs ranger's predict() and returns one named element of its result.
ranger_component <- function(component) {
  force(component)
  function(model, ...) {
    predict(model, ...)[[component]]
  }
}
plot(predict(all, mod, type = "se", seed = 4,
             fun = ranger_component("predictions")))
plot(predict(all, mod, type = "se", seed = 4,
             fun = ranger_component("se")))

I understand that I could use the standard errors (se) from ranger's predictions to build confidence intervals, but that approach has known problems (see the linked discussion). I think what I need to do is convert the raster to a matrix or data.frame, make the predictions, and convert the result back to a raster, but I'm unsure how to do this. Thanks for any input.

2

There are 2 best solutions below

2
Robert Hijmans On BEST ANSWER

Here is how to do that with the "terra" package, which has replaced "raster".

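You can use terra::predict. Because you want to use quantForestError instead of the predict method from "ranger", you need to supply it through the fun argument. You also need to name the additional arguments: terra::predict calls fun(model, data, ...), so naming X.train and Y.train makes the cell values fill quantForestError's next positional argument, X.test. To get only "pred", supply a wrapper around quantForestError that extracts that column.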

Example data

library(terra)
# Three 10 x 10 layers on the same default global lon/lat grid.
sw <- pl <- pw <- rast(ncol = 10, nrow = 10)
# Fill each layer with the first ncell() values of the matching iris column.
# Equivalent to runif(ncell(sw), min = v, max = v): runif() returns `min`
# unchanged when min == max, so no random numbers are drawn and a preceding
# set.seed() has no effect. If random values are wanted, use e.g.
# runif(ncell(sw), min(iris$Sepal.Width), max(iris$Sepal.Width)).
values(sw) <- head(iris$Sepal.Width, ncell(sw))
values(pl) <- head(iris$Petal.Length, ncell(pl))
values(pw) <- head(iris$Petal.Width, ncell(pw))
x <- c(sw, pl, pw)
# Set the first cell to NA in every layer (presumably to exercise
# na.rm = TRUE in the predict() call below).
x[1] <- NA
# Layer names must match the predictor names used to fit the model.
names(x) <- c("Sepal.Width", "Petal.Length", "Petal.Width")

Model and predictions

library(forestError)
# Random forest for Sepal.Length; keep.inbag = TRUE keeps the in-bag
# information that forestError works from.
mod <- ranger::ranger(Sepal.Length ~ . - Species, iris,
                      keep.inbag = TRUE, seed = 4)

# terra::predict() calls fun(model, <cell values>, ...). X.train and Y.train
# are passed by name, so the cell values fill quantForestError()'s next
# positional argument, X.test. The wrapper keeps only the point prediction.
pred_only <- function(...) {
  quantForestError(...)$pred
}
p <- predict(x, mod,
             fun = pred_only,
             X.train = iris[, 2:4],
             Y.train = iris[, 1],
             what = "interval",
             na.rm = TRUE,
             wopt = list(names = "qFerror"))

p
# Printed summary (every output line commented so the snippet still parses):
#class       : SpatRaster 
#dimensions  : 10, 10, 1  (nrow, ncol, nlyr)
#resolution  : 36, 18  (x, y)
#extent      : -180, 180, -90, 90  (xmin, xmax, ymin, ymax)
#coord. ref. : lon/lat WGS 84 (CRS84) (OGC:CRS84) 
#source(s)   : memory
#name        :  qFerror 
#min value   : 4.600881 
#max value   : 6.653509 
0
Kevin On

Here's a method for converting the raster stack to data.frame, calculating predictions and prediction intervals, and then converting back to raster. It is very slow but seems to perform better than calculating standard errors with the ranger package.

# One row per cell, in cell order; columns are the stack's layer names.
cell_vals <- as.data.frame(all)
# Only predict cells where every predictor is present.
ok <- complete.cases(cell_vals)
out <- quantForestError(mod, iris[, 2:4], cell_vals[ok, , drop = FALSE],
                        iris[, 1], what = "interval")

# Put one column of `out` back onto the stack's own grid by cell number.
# Cells that were dropped stay NA and the original extent/resolution/crs are
# kept, whereas rasterFromXYZ() rebuilds the grid from the non-NA cell
# centres only (which can shrink the extent when border cells are NA).
to_layer <- function(v) {
  vals <- rep(NA_real_, ncell(all))
  vals[ok] <- v
  setValues(raster(all), vals)
}
# Select output columns by name, not position.
pred_rst <- to_layer(out$pred)
pred_lwr <- to_layer(out$lower_0.05)
pred_upr <- to_layer(out$upper_0.05)
plot(stack(pred_lwr, pred_rst, pred_upr))