Removing observations conditionally (after use of MatchIt package) in R


I have used the MatchIt package to perform exact matching between a treatment group (treat = 1) and a control group (treat = 0), matching on age. The variable subclass identifies the matched units.

If a treated unit is matched to more than one control, I would like to select one control for it at random. It is important that the selection be random.

If a subclass has more treated units than controls (as in subclass 4, where two treated units are matched to a single control), I would like to discard the surplus units so that each subclass keeps the same number of treated and control units. In the end, I expect to have an equal number of observations with treat = 1 and treat = 0.

My real dataset is huge and consists of more than a million subclasses.

df <- structure(list(id = c("NSW1", "NSW57", "PSID6", "PSID84", "PSID147", 
"PSID349", "PSID361", "PSID400", "NSW2", "NSW6", "NSW9", "NSW60", 
"NSW77", "NSW80", "NSW127", "NSW161", "NSW169", "NSW177", "NSW179", 
"PSID15", "PSID31", "PSID41", "PSID62", "PSID92", "PSID93", "PSID150", 
"PSID167", "PSID178", "PSID254", "PSID292", "PSID300", "PSID308", 
"PSID309", "PSID314", "PSID330", "NSW3", "NSW55", "NSW109", "PSID1", 
"PSID69", "PSID91", "PSID165", "PSID166", "PSID302", "PSID378", 
"ASID9033", "ASID9034", "ASID9036"), treat = c(1L, 1L, 0L, 0L, 
0L, 0L, 0L, 0L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 0L, 
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 1L, 
1L, 1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 1L, 1L, 0L), age = c(37L, 
37L, 37L, 37L, 37L, 37L, 37L, 37L, 22L, 22L, 22L, 22L, 22L, 22L, 
22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 
22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 30L, 30L, 30L, 30L, 30L, 
30L, 30L, 30L, 30L, 30L, 29L, 29L, 29L), race = c("black", "black", 
"black", "hispan", "white", "white", "white", "black", "hispan", 
"black", "black", "white", "black", "black", "black", "black", 
"black", "hispan", "white", "black", "hispan", "black", "white", 
"white", "white", "hispan", "white", "white", "white", "white", 
"black", "black", "white", "white", "black", "black", "black", 
"black", "white", "black", "white", "white", "white", "white", 
"white", "black", "white", "black"), married = c(1L, 0L, 1L, 
0L, 1L, 1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 1L, 1L, 0L, 
1L, 0L, 1L, 1L, 1L, 0L, 0L, 1L, 0L, 1L, 1L, 1L, 1L, 1L, 0L, 0L, 
0L, 1L, 0L, 1L, 0L, 0L, 1L, 1L, 1L, 1L, 1L, 0L, 0L), subclass = c(1L, 
1L, 1L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 
2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 
2L, 2L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 4L, 4L, 4L)), class = "data.frame", row.names = c(NA, 
-48L))
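To make the target concrete, here is a small base R check of how many units each group can keep per subclass under the rule above: each subclass keeps min(number treated, number of controls) units of each group. As an illustrative shortcut (not part of the original data), it rebuilds only the treat and subclass columns with rep(), in the same order as the data above.

```r
# Compact stand-in for the treat and subclass columns of the data above
# (same values, same row order; the other columns are not needed here)
treat    <- c(rep(1L, 2), rep(0L, 6), rep(1L, 11), rep(0L, 16),
              rep(1L, 3), rep(0L, 7), rep(1L, 2), 0L)
subclass <- rep(1:4, times = c(8, 27, 10, 3))

# Treated and control counts per subclass
# (subclass 4 has 2 treated units but only 1 control)
counts <- table(treat, subclass)
print(counts)

# Each group keeps min(n_treated, n_control) units per subclass
keep <- apply(counts, 2, min)
print(keep)
sum(keep)  # 17 units per group, i.e. 34 rows in the balanced data
```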


Answer 1 (accepted)

Here's a (maybe a bit convoluted) way using group_split and map_dfr.

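library(tidyverse)

df %>% 
  group_split(subclass) %>%  # one tibble per subclass
  map_dfr(~ if (sum(.x$treat) > (nrow(.x) / 2))
              # more treated than controls: keep every control, sample as many treated
              bind_rows(.x[.x$treat == 0, ], sample_n(.x[.x$treat == 1, ], nrow(.x[.x$treat == 0, ])))
            else if (sum(.x$treat) < (nrow(.x) / 2))
              # more controls than treated: keep every treated unit, sample as many controls
              bind_rows(.x[.x$treat == 1, ], sample_n(.x[.x$treat == 0, ], nrow(.x[.x$treat == 1, ])))
            else .x)  # already balanced: keep the subclass as is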

# A tibble: 34 x 6
   id      treat   age race   married subclass
   <chr>   <int> <int> <chr>    <int>    <int>
 1 NSW1        1    37 black        1        1
 2 NSW57       1    37 black        0        1
 3 PSID400     0    37 black        0        1
 4 PSID84      0    37 hispan       0        1
 5 NSW2        1    22 hispan       0        2
 6 NSW6        1    22 black        0        2
 7 NSW9        1    22 black        0        2
 8 NSW60       1    22 white        0        2
 9 NSW77       1    22 black        0        2
10 NSW80       1    22 black        0        2
# ... with 24 more rows
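The rule used above keeps, in every subclass, all units of the smaller group plus an equally sized random draw from the larger group. A minimal base R sketch of the same rule follows; it works on plain vectors with ave() and sample() rather than splitting the data frame into one tibble per subclass. Whether that is faster with a million subclasses is an assumption worth benchmarking. The df below is a stand-in with the question's treat and subclass columns; its id values are made up.

```r
# Stand-in for the question's data: treat and subclass as in the question,
# hypothetical id values
df <- data.frame(
  id       = paste0("unit", 1:48),
  treat    = c(rep(1L, 2), rep(0L, 6), rep(1L, 11), rep(0L, 16),
               rep(1L, 3), rep(0L, 7), rep(1L, 2), 0L),
  subclass = rep(1:4, times = c(8, 27, 10, 3))
)

set.seed(42)  # reproducible draw; drop for a fresh one

# 1. Shuffle the rows once, so "the first k rows of a cell" is a random draw
shuffled <- df[sample(nrow(df)), ]

# 2. Position of each row within its subclass x treat cell (1, 2, 3, ...)
pos <- ave(seq_len(nrow(shuffled)), shuffled$subclass, shuffled$treat,
           FUN = seq_along)

# 3. Treated and control counts of each row's subclass
n_treated <- ave(shuffled$treat, shuffled$subclass, FUN = sum)
n_control <- ave(1L - shuffled$treat, shuffled$subclass, FUN = sum)

# 4. Keep the first min(n_treated, n_control) rows of every cell
balanced <- shuffled[pos <= pmin(n_treated, n_control), ]
balanced <- balanced[order(balanced$subclass, -balanced$treat), ]

# Both rows of this table should be identical, column by column
table(balanced$treat, balanced$subclass)
```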
Answer 2

Here's one simple approach:

library(tidyverse)
set.seed(999)
mydata %>% 
  mutate(r = runif(n = nrow(mydata))) %>%
  arrange(r) %>%
  group_by(treat, subclass) %>% 
  mutate(max_r = max(r)) %>% 
  filter(r == max_r) %>% select(-c(r, max_r)) -> mydata.filtered

I first draw a random number r for each row, then sort the data by r. Next, I compute max(r) within each subclass-by-treat cell and drop every row where r != max(r).

This leaves exactly one treated and one control observation in each subclass (so a subclass with several treated units, such as subclass 2, keeps only one of them).

> table(mydata.filtered$treat, mydata.filtered$subclass)
   
    1 2 3 4
  0 1 1 1 1
  1 1 1 1 1

Data:

mydata <- structure(list(id = c("NSW1", "NSW57", "PSID6", "PSID84", "PSID147", 
"PSID349", "PSID361", "PSID400", "NSW2", "NSW6", "NSW9", "NSW60", 
"NSW77", "NSW80", "NSW127", "NSW161", "NSW169", "NSW177", "NSW179", 
"PSID15", "PSID31", "PSID41", "PSID62", "PSID92", "PSID93", "PSID150", 
"PSID167", "PSID178", "PSID254", "PSID292", "PSID300", "PSID308", 
"PSID309", "PSID314", "PSID330", "NSW3", "NSW55", "NSW109", "PSID1", 
"PSID69", "PSID91", "PSID165", "PSID166", "PSID302", "PSID378", 
"ASID9033", "ASID9034", "ASID9036"), treat = c(1L, 1L, 0L, 0L, 
0L, 0L, 0L, 0L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 0L, 
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 1L, 
1L, 1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 1L, 1L, 0L), age = c(37L, 
37L, 37L, 37L, 37L, 37L, 37L, 37L, 22L, 22L, 22L, 22L, 22L, 22L, 
22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 
22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 30L, 30L, 30L, 30L, 30L, 
30L, 30L, 30L, 30L, 30L, 29L, 29L, 29L), race = c("black", "black", 
"black", "hispan", "white", "white", "white", "black", "hispan", 
"black", "black", "white", "black", "black", "black", "black", 
"black", "hispan", "white", "black", "hispan", "black", "white", 
"white", "white", "hispan", "white", "white", "white", "white", 
"black", "black", "white", "white", "black", "black", "black", 
"black", "white", "black", "white", "white", "white", "white", 
"white", "black", "white", "black"), married = c(1L, 0L, 1L, 
0L, 1L, 1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 1L, 1L, 0L, 
1L, 0L, 1L, 1L, 1L, 0L, 0L, 1L, 0L, 1L, 1L, 1L, 1L, 1L, 0L, 0L, 
0L, 1L, 0L, 1L, 0L, 0L, 1L, 1L, 1L, 1L, 1L, 0L, 0L), subclass = c(1L, 
1L, 1L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 
2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 
2L, 2L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 4L, 4L, 4L)), class = "data.frame", row.names = c(NA, 
-48L))

Judging from the MatchIt documentation (?matchit), matchit() also has a ratio argument, which can be used to force 1:1 matching within the matching call itself.

Answer 3

Another (base R) approach:

md <- do.call("rbind", unname(lapply(split(md, ~subclass),
                                     function(x) {
                                         x[c(which(x$treat == 1)[1], 
                                             which(x$treat == 0)[1]),]
                                     })))

This grabs the first treated and the first control unit from each subclass, then rbinds them all together. If your data are randomly ordered, this is equivalent to randomly selecting one treated and one control unit per subclass.
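In the question's data the rows are sorted by subclass with the treated units listed first, so "first" is not random there. A minimal sketch that shuffles the rows with sample() before applying the same code; md here is a stand-in with the question's treat and subclass columns (the id values are made up), and split(md, md$subclass) is used in place of the formula form so it also runs on R versions older than 4.1.

```r
# Stand-in for the matched data md (hypothetical ids; treat/subclass as in the question)
md <- data.frame(
  id       = paste0("unit", 1:48),
  treat    = c(rep(1L, 2), rep(0L, 6), rep(1L, 11), rep(0L, 16),
               rep(1L, 3), rep(0L, 7), rep(1L, 2), 0L),
  subclass = rep(1:4, times = c(8, 27, 10, 3))
)

set.seed(123)  # reproducible draw; drop for a fresh one

# Shuffle the rows first, so "first treated / first control" is a random pick
md <- md[sample(nrow(md)), ]

# Same selection as above: first treated and first control unit per subclass
md <- do.call("rbind", unname(lapply(split(md, md$subclass),
                                     function(x) {
                                         x[c(which(x$treat == 1)[1],
                                             which(x$treat == 0)[1]), ]
                                     })))
```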