I've found that calling waic or loo through do.call produces a result that uses far more memory than calling the function directly. (My overall goal is to pass an arbitrary number of lists of models to a function, compare the models within each list using waic/loo, and return the results as a list. Note that lapply applies waic to each model individually, so no comparison is made, which isn't what I want.) As an example:
library(brms)
d <- data.frame(x = rnorm(100))
d$y <- rnorm(100)
d$z <- rnorm(100)
m1 <- brm(formula = y ~ x, data = d)
m2 <- brm(formula = y ~ z, data = d)
w1 <- waic(m1, m2)
w2 <- do.call(waic, list(m1, m2))
class(w1)
class(w2)
object.size(w1)
object.size(w2)
This shows me the output:
> class(w1)
[1] "loolist"
> class(w2)
[1] "loolist"
> object.size(w1)
15440 bytes
> object.size(w2)
11138384 bytes
The same thing happens if I use loo, so it's not specific to the waic function.
If I look at the results of str, it seems like w2 is storing the full models, not just the loo results, which is probably what's leading to the bloat. str(w1) produces:
List of 3
$ loos :List of 2
..$ m1:List of 8
.. ..$ estimates : num [1:3, 1:2] -152.802 2.757 305.604 6.602 0.457 ...
.. .. ..- attr(*, "dimnames")=List of 2
.. .. .. ..$ : chr [1:3] "elpd_waic" "p_waic" "waic"
.. .. .. ..$ : chr [1:2] "Estimate" "SE"
.. ..$ pointwise : num [1:100, 1:3] -1.11 -1.08 -1.04 -2.05 -1.97 ...
.. .. ..- attr(*, "dimnames")=List of 2
.. .. .. ..$ : NULL
.. .. .. ..$ : chr [1:3] "elpd_waic" "p_waic" "waic"
.. ..$ elpd_waic : num -153
.. ..$ p_waic : num 2.76
.. ..$ waic : num 306
.. ..$ se_elpd_waic: num 6.6
.. ..$ se_p_waic : num 0.457
.. ..$ se_waic : num 13.2
.. ..- attr(*, "dims")= int [1:2] 4000 100
.. ..- attr(*, "class")= chr [1:2] "waic" "loo"
.. ..- attr(*, "yhash")= chr "543895fddfbc9e53e6fdf43fd5d57f2d95a3893f"
.. ..- attr(*, "model_name")= chr "m1"
..$ m2:List of 8
.. ..$ estimates : num [1:3, 1:2] -153.55 2.958 307.1 6.72 0.592 ...
.. .. ..- attr(*, "dimnames")=List of 2
.. .. .. ..$ : chr [1:3] "elpd_waic" "p_waic" "waic"
.. .. .. ..$ : chr [1:2] "Estimate" "SE"
.. ..$ pointwise : num [1:100, 1:3] -1.05 -1.12 -1.04 -1.95 -1.89 ...
.. .. ..- attr(*, "dimnames")=List of 2
.. .. .. ..$ : NULL
.. .. .. ..$ : chr [1:3] "elpd_waic" "p_waic" "waic"
.. ..$ elpd_waic : num -154
.. ..$ p_waic : num 2.96
.. ..$ waic : num 307
.. ..$ se_elpd_waic: num 6.72
.. ..$ se_p_waic : num 0.592
.. ..$ se_waic : num 13.4
.. ..- attr(*, "dims")= int [1:2] 4000 100
.. ..- attr(*, "class")= chr [1:2] "waic" "loo"
.. ..- attr(*, "yhash")= chr "543895fddfbc9e53e6fdf43fd5d57f2d95a3893f"
.. ..- attr(*, "model_name")= chr "m2"
$ diffs : 'compare.loo' num [1:2, 1:8] 0 -0.748 0 1.384 -152.802 ...
..- attr(*, "dimnames")=List of 2
.. ..$ : chr [1:2] "m1" "m2"
.. ..$ : chr [1:8] "elpd_diff" "se_diff" "elpd_waic" "se_elpd_waic" ...
$ ic_diffs__: num [1, 1:2] -1.5 2.77
..- attr(*, "dimnames")=List of 2
.. ..$ : chr "m1 - m2"
.. ..$ : chr [1:2] "WAIC" "SE"
- attr(*, "class")= chr "loolist"
str(w2) produces something that starts like this (I've formatted it a bit so it's not all on one line):
List of 3
$ loos :List of 2
..$ structure(list(formula = structure(list(formula = y ~ x, pforms = list(
), pfix = list(
), resp = "y", family = structure(list(family = "gaussian", link = "identity", linkfun = function (mu) link(mu, link = slink
), linkinv = function (eta) ilink(eta, link = slink
), dpars = c("mu", "sigma"
), type = "real", ybounds = c(-Inf, Inf
), closed = c(NA, NA
), ad = c("weights", "subset", "se", "cens", "trunc", "mi"
), specials = c("residuals", "rescor")
), class = c("brmsfamily", "family")
), mecor = TRUE
), class = c("brmsformula", "bform")
), data = structure(list(y = c(0.222865037574825, 0.570128066864863, -0.135685571969938, 1.51400901767955, -1.42294537600286, 3.05953211020977, 1.77319875465707, -0.560055434573721, -2.17233843283651, 0.244891176968169, 0.852334318135536, -0.600487439599353, 0.37978067519301, -2.39510231005513, 0.595899684607949, 1.03714876676312, -1.41064279609684, -0.233463367708363, 1.83336396424474, 0.265639715825803, 0.896216012007527, -0.511005506576839, -0.990258505999856, -0.558980878174761, 1.05346814525893, -0.0899664221908207, -1.33110641888477, -0.188959731751976, -0.679800813713052, -0.321769958513427, 0.668316298445943, -0.427812609325669, -0.961094772216214, -0.194752817585835, 0.296422206866777, 0.274271008303485, -0.723313689099733, 1.39275639468154, -1.91894463948922, 0.322229704542086, -0.840174756057233, 1.46068495897767, -2.20030669514299, 0.766461853640603, -0.932998510936173, -0.389747230010412, 1.21222512351531, -0.367388491045232, 1.4680382011575, -0.248155597915535, 0.881321799678558, 0.928975863372169, -0.40097077944702, -0.932753769947884, -0.0787743622003075, 0.934773284552975, 0.802022931150372, 0.539236485182244, 1.90920483698261, 0.0959007938340177, 1.90355811256047, 1.23880369756146, 0.399161008234409, -1.61444070712025, 1.09228096133471, -0.334304984354992, 1.94919630176386, 0.388723697364589, -0.676560429449814, -0.486799514926963, -1.69814420604508, 0.610643539656719, -1.7358003804893, -1.18947570640705, -0.2048251685502, -0.244308782100457, 0.680149129765254, -0.194941685094996, 1.2495923072767, -0.508497293855665, -2.01720057673334, 0.717582169549046, -0.974164402299986, 0.191975856490787, -0.47647793799195, 1.30995888921046, 0.612325939317696, 1.68149269824028, 0.186230025956939, 1.4523083782372, 0.247246158932475, -0.749612962281774, 1.68680561431638, -1.4512951423047, -1.93333213179369, 1.20046240239089, -0.13573165368954, -0.545721678032106, -0.0145610725616504, -0.268693593357222
), x = c(1.55551806854814, -2.30450038318989, 0.296995230680729, 0.668154099392684, -0.1547006917081, -0.309984806283829, 0.495221682683779, -0.562225239149049, 0.347241974318799, -0.124006626810845, -1.44831536829081, 0.00816231155818208, 1.19269682091828, 1.09657509945784, -0.635792796784607, 0.387962591093174, 0.0706000447502193, -1.00525115255797, -0.289935304272473, -0.473562989522596, -0.783401194401419, -0.847650167792549, 0.275627321425429, -0.135952419089467, -0.77243964938835, 1.1657114904043, -1.01856242301745, -0.463154535428835, -0.586387548561369, 0.286936164713612, -0.940443323746505, 0.340709452808652, -0.894584173330604, 0.591347212942816, 0.949953660301979, 0.321956819738089, 0.580290486749746, -0.68755984894602, 1.46867714260446, 0.535535223822069, 0.80492806522492, 1.32291579645674, -1.08656956060417, 1.31416520144562, 1.48015770667234, 0.977243130228241, 0.0222292740372709, 1.39536233198176, 2.37899133963515, 2.55012059548256, -0.601237284663627, -1.10848718399703, -0.718810050747427, 1.0122041455424, -0.582118731616551, 0.0958348014993671, -0.983464989091578, 0.0985791907557584, 0.239011227739015, 0.172632248363633, -0.758182492239057, -1.48857289853501, 0.0685537104701556, 0.51044243736898, 0.482978773054692, 0.394488590480919, 0.608583888747901, -0.535321605973326, 0.617601681089166, -0.142879029528284, -0.727413272882995, -1.14984815574316, -0.34637735244817, 0.784041932914986, 1.02247745933603, 0.512966807729552, -1.43056324735507, 0.649198431445028, 0.783348344111711, -1.44908235653234, 0.216245543136908, -0.324521163713685, 1.25515815254712, -0.0190663771221899, -0.543757718828948, -0.193192758506608, 1.11354415805662, 0.484296073217596, -0.00441288006210047, -1.59834412372779, -0.463597727522575, 0.95437185173169, -0.00258922453282936, -0.595437835005253, 1.1527271628746, -0.00240938398175376, -0.733006155939103, -0.0418069102112059, -0.97696234693589, -0.913221059230367)
), class = "data.frame", row.names = c(NA, 100L
), terms = y ~ y + x, data_name = "d"
), prior = structure(list( prior = c("", "", "student_t(3, 0, 2.5)", "student_t(3, 0, 2.5)"
), class = c("b", "b", "Intercept", "sigma"
), coef = c("", "x", "", ""
), group = c("", "", "", ""
), resp = c("", "", "", ""
), dpar = c("", "", "", ""
), nlpar = c("", "", "", ""
), bound = c("", "", "", "")
), special = list(mu = list()
), row.names = c(NA, -4L
), class = c("brmsprior", "data.frame"
), sample_prior = "no"
), data2 = list(
), stanvars = structure(list(
), class = "stanvars"
), model = structure("// generated with brms 2.13.5\nfunctions {\n}\ndata {\n int<lower=1> N; // number of observations\n vector[N] Y; // response variable\n int<lower=1> K; // number of population-level effects\n matrix[N, K] X; // population-level design matrix\n int prior_only; // should the likelihood be ignored?\n}\ntransformed data {\n int Kc = K - 1;\n matrix[N, Kc] Xc; // centered version of X without an intercept\n vector[Kc] means_X; // column means of X before centering\n for (i in 2:K) {\n means_X[i - 1] = mean(X[, i]);\n Xc[, i - 1] = X[, i] - means_X[i - 1];\n }\n}\nparameters {\n vector[Kc] b; // population-level effects\n real Intercept; // temporary intercept for centered predictors\n real<lower=0> sigma; // residual SD\n}\ntransformed parameters {\n}\nmodel {\n // priors including all constants\n target += student_t_lpdf(Intercept | 3, 0, 2.5);\n target += student_t_lpdf(sigma | 3, 0, 2.5)\n - 1 * student_t_lccdf(0 | 3, 0, 2.5);\n // likelihood including all constants\n if (!prior_only) {\n target += normal_id_glm_lpdf(Y | Xc, Intercept, b, sigma);\n }\n}\ngenerated quantities {\n // actual population-level intercept\n real b_Intercept = Intercept - dot_product(means_X, b);\n}\n", class = c("character", "brmsmodel")
), ranef = structure(list(id = numeric(0
), group = character(0
), gn = numeric(0
), ...
(This continues for many, many lines.) Trying to print out anything else hangs R and I have to force quit.
It seems to me like do.call is not unpacking the elements of the list like I'd hoped. Is there a way to fix this?
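One likely explanation can be sketched in base R, without fitting any models. The function label_args below is a hypothetical stand-in, not brms code: it assumes that waic/loo name each model by deparsing the expression the caller supplied, which is a common R idiom. The str(w2) output above is consistent with that assumption, since the element name under loos is a deparsed structure(...) rather than m1. When do.call receives a list of values, the constructed call contains the objects themselves, so the deparsed "name" is the whole object; passing symbols instead keeps the names short.

```r
# Hypothetical stand-in for waic()/loo(): labels each argument by deparsing
# the expression the caller supplied (assumed behaviour, not brms internals).
label_args <- function(...) {
  exprs <- as.list(substitute(list(...)))[-1]
  vapply(exprs, function(e) paste(deparse(e), collapse = ""), character(1))
}

a <- c(1.5, 2.5, 3.5)
b <- c(4.5, 5.5, 6.5)

# Direct call: the arguments are the symbols a and b.
direct <- label_args(a, b)

# do.call with values: the call embeds the vectors, so the labels are
# the deparsed contents rather than the variable names.
by_value <- do.call(label_args, list(a, b))

# do.call with symbols: the call is label_args(a, b) again.
by_name <- do.call(label_args, list(quote(a), quote(b)))

print(direct)
print(by_value)
print(by_name)
```

If the same mechanism applies to brms, then do.call(waic, list(quote(m1), quote(m2))) should behave like waic(m1, m2), because do.call evaluates the symbols when the call runs.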
Edit:
I was able to use MrFlick's comment to figure out how to do what I wanted with this, but I'm still curious about the difference in that method vs. just using do.call as above, so I'll leave the question here.
For the curious, the solution is:
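foo <- function(...) {
  # Capture each argument unevaluated, e.g. the call list(m1, m2)
  dots <- eval(substitute(alist(...)))
  res <- list()
  for (i in seq_along(dots)) {
    # Drop the leading "list" and keep the model expressions as strings
    res[[i]] <- as.character(dots[[i]])[-1]
    # Convert each string back into an unevaluated expression (e.g. the symbol m1)
    res[[i]] <- lapply(res[[i]], FUN = function(x) eval(parse(text = paste0('quote(', x, ')'))))
    # waic now receives names rather than fitted objects
    res[[i]] <- do.call('waic', res[[i]])
  }
  return(res)
}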
If you call it like this: x <- foo(list(m1, m2), list(m1, m3)) (where m3 is a third fitted model), the result is a list whose i-th element contains the results of applying waic to the models in the i-th list passed to foo.
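A variation on the same idea, offered as a sketch rather than a tested brms recipe: pass the model names as character vectors and convert them to symbols with as.name, which avoids the parse/eval round trip. Both compare_by_name and summarise_named are hypothetical names introduced here; summarise_named stands in for waic so the example runs in base R.

```r
# Hypothetical stand-in for waic(): returns the mean of each argument,
# labelled by the expression used in the call.
summarise_named <- function(...) {
  exprs <- as.list(substitute(list(...)))[-1]
  labels <- vapply(exprs, function(e) paste(deparse(e), collapse = ""), character(1))
  stats::setNames(vapply(list(...), mean, numeric(1)), labels)
}

# Hypothetical helper: each element of ... is a character vector of object
# names; fun is called once per vector with those names turned into symbols.
compare_by_name <- function(fun, ..., envir = parent.frame()) {
  force(envir)  # fix the lookup environment before entering lapply
  lapply(list(...), function(object_names) {
    do.call(fun, lapply(object_names, as.name), envir = envir)
  })
}

a <- c(1, 2, 3)
b <- c(4, 5, 6)
d <- c(7, 8, 9)

res <- compare_by_name(summarise_named, c("a", "b"), c("a", "d"))
print(res)
```

With fitted models, the corresponding call would presumably be compare_by_name(waic, c("m1", "m2"), c("m1", "m3")), assuming waic labels models the way the stand-in does.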