R: Why is using do.call here giving me much bigger results?

I've discovered some odd behavior when calling waic or loo via do.call: the result takes up much, much more memory. (My overall goal is to pass an arbitrary number of lists of models to a function, compare the models within each list using waic/loo, and return the results as a list. If I use lapply instead, waic is applied to each model individually, so no comparison is made, which isn't what I want.) As an example:

library(brms)

d <- data.frame(x = rnorm(100))
d$y <- rnorm(100)
d$z <- rnorm(100)

m1 <- brm(formula = y ~ x, data = d)
m2 <- brm(formula = y ~ z, data = d)

w1 <- waic(m1, m2)
w2 <- do.call(waic, list(m1, m2))

class(w1)
class(w2)

object.size(w1)
object.size(w2)

This shows me the output:

> class(w1)
[1] "loolist"
> class(w2)
[1] "loolist"
> object.size(w1)
15440 bytes
> object.size(w2)
11138384 bytes

The same thing happens if I use loo, so it's not specific to the waic function.

If I look at the results of str, it seems like w2 is storing the full models, not just the loo results, which is probably what's causing the bloat. str(w1) produces:

List of 3
 $ loos      :List of 2
  ..$ m1:List of 8
  .. ..$ estimates   : num [1:3, 1:2] -152.802 2.757 305.604 6.602 0.457 ...
  .. .. ..- attr(*, "dimnames")=List of 2
  .. .. .. ..$ : chr [1:3] "elpd_waic" "p_waic" "waic"
  .. .. .. ..$ : chr [1:2] "Estimate" "SE"
  .. ..$ pointwise   : num [1:100, 1:3] -1.11 -1.08 -1.04 -2.05 -1.97 ...
  .. .. ..- attr(*, "dimnames")=List of 2
  .. .. .. ..$ : NULL
  .. .. .. ..$ : chr [1:3] "elpd_waic" "p_waic" "waic"
  .. ..$ elpd_waic   : num -153
  .. ..$ p_waic      : num 2.76
  .. ..$ waic        : num 306
  .. ..$ se_elpd_waic: num 6.6
  .. ..$ se_p_waic   : num 0.457
  .. ..$ se_waic     : num 13.2
  .. ..- attr(*, "dims")= int [1:2] 4000 100
  .. ..- attr(*, "class")= chr [1:2] "waic" "loo"
  .. ..- attr(*, "yhash")= chr "543895fddfbc9e53e6fdf43fd5d57f2d95a3893f"
  .. ..- attr(*, "model_name")= chr "m1"
  ..$ m2:List of 8
  .. ..$ estimates   : num [1:3, 1:2] -153.55 2.958 307.1 6.72 0.592 ...
  .. .. ..- attr(*, "dimnames")=List of 2
  .. .. .. ..$ : chr [1:3] "elpd_waic" "p_waic" "waic"
  .. .. .. ..$ : chr [1:2] "Estimate" "SE"
  .. ..$ pointwise   : num [1:100, 1:3] -1.05 -1.12 -1.04 -1.95 -1.89 ...
  .. .. ..- attr(*, "dimnames")=List of 2
  .. .. .. ..$ : NULL
  .. .. .. ..$ : chr [1:3] "elpd_waic" "p_waic" "waic"
  .. ..$ elpd_waic   : num -154
  .. ..$ p_waic      : num 2.96
  .. ..$ waic        : num 307
  .. ..$ se_elpd_waic: num 6.72
  .. ..$ se_p_waic   : num 0.592
  .. ..$ se_waic     : num 13.4
  .. ..- attr(*, "dims")= int [1:2] 4000 100
  .. ..- attr(*, "class")= chr [1:2] "waic" "loo"
  .. ..- attr(*, "yhash")= chr "543895fddfbc9e53e6fdf43fd5d57f2d95a3893f"
  .. ..- attr(*, "model_name")= chr "m2"
 $ diffs     : 'compare.loo' num [1:2, 1:8] 0 -0.748 0 1.384 -152.802 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:2] "m1" "m2"
  .. ..$ : chr [1:8] "elpd_diff" "se_diff" "elpd_waic" "se_elpd_waic" ...
 $ ic_diffs__: num [1, 1:2] -1.5 2.77
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr "m1 - m2"
  .. ..$ : chr [1:2] "WAIC" "SE"
 - attr(*, "class")= chr "loolist"

str(w2) produces something that starts like this (I've formatted it a bit so it's not all on one line):

List of 3
 $ loos      :List of 2
  ..$ structure(list(formula = structure(list(formula = y ~ x, pforms = list(
),     pfix = list(
), resp = "y", family = structure(list(family = "gaussian",         link = "identity", linkfun = function (mu)         link(mu, link = slink
), linkinv = function (eta)         ilink(eta, link = slink
), dpars = c("mu", "sigma"
), type = "real",         ybounds = c(-Inf, Inf
), closed = c(NA, NA
), ad = c("weights",         "subset", "se", "cens", "trunc", "mi"
), specials = c("residuals",         "rescor")
), class = c("brmsfamily", "family")
), mecor = TRUE
), class = c("brmsformula", "bform")
), data = structure(list(y = c(0.222865037574825, 0.570128066864863, -0.135685571969938, 1.51400901767955, -1.42294537600286, 3.05953211020977, 1.77319875465707, -0.560055434573721, -2.17233843283651, 0.244891176968169, 0.852334318135536, -0.600487439599353, 0.37978067519301, -2.39510231005513, 0.595899684607949, 1.03714876676312, -1.41064279609684, -0.233463367708363, 1.83336396424474, 0.265639715825803, 0.896216012007527, -0.511005506576839, -0.990258505999856, -0.558980878174761, 1.05346814525893, -0.0899664221908207, -1.33110641888477, -0.188959731751976, -0.679800813713052, -0.321769958513427, 0.668316298445943, -0.427812609325669, -0.961094772216214, -0.194752817585835, 0.296422206866777, 0.274271008303485, -0.723313689099733, 1.39275639468154, -1.91894463948922, 0.322229704542086, -0.840174756057233, 1.46068495897767, -2.20030669514299, 0.766461853640603, -0.932998510936173, -0.389747230010412, 1.21222512351531, -0.367388491045232, 1.4680382011575, -0.248155597915535, 0.881321799678558, 0.928975863372169, -0.40097077944702, -0.932753769947884, -0.0787743622003075, 0.934773284552975, 0.802022931150372, 0.539236485182244, 1.90920483698261, 0.0959007938340177, 1.90355811256047, 1.23880369756146, 0.399161008234409, -1.61444070712025, 1.09228096133471, -0.334304984354992, 1.94919630176386, 0.388723697364589, -0.676560429449814, -0.486799514926963, -1.69814420604508, 0.610643539656719, -1.7358003804893, -1.18947570640705, -0.2048251685502, -0.244308782100457, 0.680149129765254, -0.194941685094996, 1.2495923072767, -0.508497293855665, -2.01720057673334, 0.717582169549046, -0.974164402299986, 0.191975856490787, -0.47647793799195, 1.30995888921046, 0.612325939317696, 1.68149269824028, 0.186230025956939, 1.4523083782372, 0.247246158932475, -0.749612962281774, 1.68680561431638, -1.4512951423047, -1.93333213179369, 1.20046240239089, -0.13573165368954, -0.545721678032106, -0.0145610725616504, -0.268693593357222
), x = c(1.55551806854814, -2.30450038318989, 0.296995230680729, 0.668154099392684, -0.1547006917081, -0.309984806283829, 0.495221682683779, -0.562225239149049, 0.347241974318799, -0.124006626810845, -1.44831536829081, 0.00816231155818208, 1.19269682091828, 1.09657509945784, -0.635792796784607, 0.387962591093174, 0.0706000447502193, -1.00525115255797, -0.289935304272473, -0.473562989522596, -0.783401194401419, -0.847650167792549, 0.275627321425429, -0.135952419089467, -0.77243964938835, 1.1657114904043, -1.01856242301745, -0.463154535428835, -0.586387548561369, 0.286936164713612, -0.940443323746505, 0.340709452808652, -0.894584173330604, 0.591347212942816, 0.949953660301979, 0.321956819738089, 0.580290486749746, -0.68755984894602, 1.46867714260446, 0.535535223822069, 0.80492806522492, 1.32291579645674, -1.08656956060417, 1.31416520144562, 1.48015770667234, 0.977243130228241, 0.0222292740372709, 1.39536233198176, 2.37899133963515, 2.55012059548256, -0.601237284663627, -1.10848718399703, -0.718810050747427, 1.0122041455424, -0.582118731616551, 0.0958348014993671, -0.983464989091578, 0.0985791907557584, 0.239011227739015, 0.172632248363633, -0.758182492239057, -1.48857289853501, 0.0685537104701556, 0.51044243736898, 0.482978773054692, 0.394488590480919, 0.608583888747901, -0.535321605973326, 0.617601681089166, -0.142879029528284, -0.727413272882995, -1.14984815574316, -0.34637735244817, 0.784041932914986, 1.02247745933603, 0.512966807729552, -1.43056324735507, 0.649198431445028, 0.783348344111711, -1.44908235653234, 0.216245543136908, -0.324521163713685, 1.25515815254712, -0.0190663771221899, -0.543757718828948, -0.193192758506608, 1.11354415805662, 0.484296073217596, -0.00441288006210047, -1.59834412372779, -0.463597727522575, 0.95437185173169, -0.00258922453282936, -0.595437835005253, 1.1527271628746, -0.00240938398175376, -0.733006155939103, -0.0418069102112059, -0.97696234693589, -0.913221059230367)
), class = "data.frame", row.names = c(NA, 100L
), terms = y ~ y + x, data_name = "d"
), prior = structure(list(    prior = c("", "", "student_t(3, 0, 2.5)", "student_t(3, 0, 2.5)"    
), class = c("b", "b", "Intercept", "sigma"
), coef = c("",     "x", "", ""
), group = c("", "", "", ""
), resp = c("", "",     "", ""
), dpar = c("", "", "", ""
), nlpar = c("", "", "",     ""
), bound = c("", "", "", "")
), special = list(mu = list()
), row.names = c(NA, -4L
), class = c("brmsprior", "data.frame"
), sample_prior = "no"
),     data2 = list(
), stanvars = structure(list(
), class = "stanvars"
),     model = structure("// generated with brms 2.13.5\nfunctions {\n}\ndata {\n  int<lower=1> N;  // number of observations\n  vector[N] Y;  // response variable\n  int<lower=1> K;  // number of population-level effects\n  matrix[N, K] X;  // population-level design matrix\n  int prior_only;  // should the likelihood be ignored?\n}\ntransformed data {\n  int Kc = K - 1;\n  matrix[N, Kc] Xc;  // centered version of X without an intercept\n  vector[Kc] means_X;  // column means of X before centering\n  for (i in 2:K) {\n    means_X[i - 1] = mean(X[, i]);\n    Xc[, i - 1] = X[, i] - means_X[i - 1];\n  }\n}\nparameters {\n  vector[Kc] b;  // population-level effects\n  real Intercept;  // temporary intercept for centered predictors\n  real<lower=0> sigma;  // residual SD\n}\ntransformed parameters {\n}\nmodel {\n  // priors including all constants\n  target += student_t_lpdf(Intercept | 3, 0, 2.5);\n  target += student_t_lpdf(sigma | 3, 0, 2.5)\n    - 1 * student_t_lccdf(0 | 3, 0, 2.5);\n  // likelihood including all constants\n  if (!prior_only) {\n    target += normal_id_glm_lpdf(Y | Xc, Intercept, b, sigma);\n  }\n}\ngenerated quantities {\n  // actual population-level intercept\n  real b_Intercept = Intercept - dot_product(means_X, b);\n}\n", class = c("character",     "brmsmodel")
), ranef = structure(list(id = numeric(0
), group = character(0
),         gn = numeric(0
), ...

(This continues for many, many lines.) Trying to print out anything else hangs R and I have to force quit.

It seems to me like do.call is not unpacking the elements of the list like I'd hoped. Is there a way to fix this?
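do.call does unpack the list; the catch is what it unpacks into. It builds the call from the values in the list, so waic receives the model objects themselves rather than the names m1 and m2. Judging by the str(w2) output, where the element name under $loos is the deparsed model, waic/loo appear to label each model by deparsing the expression it was called with, so the whole fitted model ends up stored as a very long name. The sketch below reproduces that mechanism with a hypothetical toy function, label_args, which labels its inputs the same way; it is an illustration of the assumed behavior, not brms code.

```r
# Hypothetical stand-in for a function that, like waic()/loo() seem to,
# labels its inputs by deparsing the expressions it was called with.
label_args <- function(...) {
  exprs <- as.list(substitute(list(...)))[-1]   # unevaluated arguments
  labels <- vapply(exprs, function(e) paste(deparse(e), collapse = ""),
                   character(1))
  setNames(list(...), labels)
}

m1 <- list(coef = c(0.5, 1.2))
m2 <- list(coef = c(-0.3, 2.0))

# 1. Direct call: the argument expressions are the symbols m1 and m2.
direct <- label_args(m1, m2)
names(direct)    # "m1" "m2"

# 2. do.call with a list of values: the constructed call contains the
#    objects themselves, so deparse() spells each one out in full.
by_value <- do.call(label_args, list(m1, m2))
names(by_value)  # deparsed objects, e.g. starting with "list(coef = "

# 3. do.call with symbols: the call is label_args(m1, m2) again, and the
#    symbols are looked up in the environment do.call was called from.
by_name <- do.call(label_args, list(quote(m1), quote(m2)))
names(by_name)   # "m1" "m2"
```

The values are identical in all three cases; only the labels differ. If brms labels models this way, the same fix should carry over: do.call(waic, list(quote(m1), quote(m2))), or, starting from a character vector of names, do.call(waic, lapply(c("m1", "m2"), as.name)). Recent brms versions also document a model_names argument for loo() and waic() on brmsfit objects (see ?loo.brmsfit), which should let you set the labels explicitly; check that your installed version has it.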


Edit: I was able to use MrFlick's comment to figure out how to do what I wanted, but I'm still curious about the difference between that method and just using do.call as above, so I'll leave the question here.

For the curious, the solution is:

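foo <- function(...){
     # Capture the unevaluated arguments, e.g. list(m1, m2) and list(m1, m3)
     dots <- eval(substitute(alist(...)))
     res <- list()
     for(i in seq_along(dots)){
          # Drop the leading `list` to keep only the model symbols m1, m2, ...
          model_syms <- as.list(dots[[i]])[-1]
          # Calling waic() with symbols lets it label the models "m1", "m2";
          # evaluate the call in foo's caller, where the models are defined
          res[[i]] <- do.call('waic', model_syms, envir = parent.frame())
     }
     return(res)
}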

If you call it as x <- foo(list(m1, m2), list(m1, m3)), the result is a list whose i-th element holds the output of applying waic to the models in the i-th list passed to foo.
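A variant that avoids the string-parsing step is sketched below. It assumes each list is named (list(m1 = m1, m2 = m2)), so the labels are available without substitute(); compare_within and toy_ic are hypothetical names, and toy_ic only stands in for waic/loo so the sketch runs without brms. For each list, the models are bound to their names in a fresh environment with list2env, and the call is built from symbols (as.name) and evaluated there, so a deparse-based labeller sees m1, m2 rather than the model objects.

```r
# Hypothetical helper: compare the models inside each of several named lists.
# Assumes ic_fun labels its inputs by deparsing its argument expressions.
compare_within <- function(..., ic_fun) {
  model_lists <- list(...)
  lapply(model_lists, function(models) {
    stopifnot(!is.null(names(models)), all(names(models) != ""))
    # Bind each model to its name in a fresh environment ...
    env <- list2env(models)
    # ... and evaluate ic_fun(<name1>, <name2>, ...) there, using symbols,
    # so the labels are the short names rather than deparsed objects.
    do.call(ic_fun, lapply(names(models), as.name), envir = env)
  })
}

# Toy stand-in for waic()/loo(): returns the labels it derives by
# deparsing the expressions it was called with.
toy_ic <- function(...) {
  exprs <- as.list(substitute(list(...)))[-1]
  vapply(exprs, function(e) paste(deparse(e), collapse = ""), character(1))
}

res <- compare_within(list(a = 1:3, b = 4:6), list(a = 1:3, c = 7:9),
                      ic_fun = toy_ic)
res[[1]]  # "a" "b"
res[[2]]  # "a" "c"
```

With brms loaded, the call would presumably be compare_within(list(m1 = m1, m2 = m2), list(m1 = m1, m3 = m3), ic_fun = waic). Because the symbols are resolved in the dedicated environment, the lookup does not depend on where the models were originally defined.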
