How to filter a data.table based on an uncertain number of conditions?


Given the following data.table in R:

library(data.table)

set.seed(123666)
dt <- data.table(sample1 = sample(10), 
                 sample2 = sample(10),
                 sample3 = sample(10),
                 sample4 = sample(10), 
                 sample5 = sample(10),
                 sample6 = sample(10))
dt
    sample1 sample2 sample3 sample4 sample5 sample6
 1:       2       6       3       9       1       2
 2:      10       9      10       3       7       5
 3:       6      10       8       5       5       1
 4:       8       2       9       8       6       6
 5:       5       4       5      10      10       8
 6:       7       1       7       4       4      10
 7:       4       3       1       6       3       7
 8:       1       5       6       1       2       3
 9:       3       7       2       2       8       9
10:       9       8       4       7       9       4

Suppose the first 3 samples belong to group_a and the last 3 to group_b. We want to keep the rows in which, within each group, at least 2 of the 3 samples are greater than 2. For this fixed layout, the following code works:

group_a <- paste0('sample', seq(1,3))
group_b <- paste0('sample', seq(4,6))

dt[rowSums(dt[, ..group_a] > 2) >= 2 & rowSums(dt[, ..group_b] > 2) >= 2]
   sample1 sample2 sample3 sample4 sample5 sample6
1:      10       9      10       3       7       5
2:       6      10       8       5       5       1
3:       8       2       9       8       6       6
4:       5       4       5      10      10       8
5:       7       1       7       4       4      10
6:       4       3       1       6       3       7
7:       3       7       2       2       8       9
8:       9       8       4       7       9       4

Now consider a data.table where each column is still a sample, but the number of samples is not known in advance. A named vector, group, describes the grouping: its values are the sample names and its names are the group IDs:

group <- paste0('sample', seq(1,6))
group_id <- c(rep('group_a', 3), rep('group_b', 3))
names(group) <- group_id 
group
  group_a   group_a   group_a   group_b   group_b   group_b 
"sample1" "sample2" "sample3" "sample4" "sample5" "sample6"

How can this be done with data.table syntax, as concisely as possible?

lotus On BEST ANSWER

You can split the sample names by group, check the condition on each group's columns, and then combine the per-group results with Reduce to get the row filter:

library(data.table)

dt[Reduce(`&`, lapply(split(group, names(group)), \(x) rowSums(dt[, .SD, .SDcols = x] > 1) >= 2 )), ]

   sample1 sample2 sample3 sample4 sample5 sample6
 1:       2       6       3       9       1       2
 2:      10       9      10       3       7       5
 3:       6      10       8       5       5       1
 4:       8       2       9       8       6       6
 5:       5       4       5      10      10       8
 6:       7       1       7       4       4      10
 7:       4       3       1       6       3       7
 8:       1       5       6       1       2       3
 9:       3       7       2       2       8       9
10:       9       8       4       7       9       4

With a threshold of 1, every row passes, so the result is the full table. Raising the threshold to match the question (at least two values greater than 2 in each group) drops rows 1 and 8:

dt[Reduce(`&`, lapply(split(group, names(group)), \(x) rowSums(dt[, .SD, .SDcols = x] > 2) >= 2 )), ]

   sample1 sample2 sample3 sample4 sample5 sample6
1:      10       9      10       3       7       5
2:       6      10       8       5       5       1
3:       8       2       9       8       6       6
4:       5       4       5      10      10       8
5:       7       1       7       4       4      10
6:       4       3       1       6       3       7
7:       3       7       2       2       8       9
8:       9       8       4       7       9       4
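
If the threshold and the minimum count also vary, the one-liner above can be wrapped in a small function. The following is only a sketch, not part of the original answer: filter_by_group, toy, and toy_group are hypothetical names introduced here for illustration, and the two groups deliberately differ in size.

```r
library(data.table)

# Hypothetical helper: keep rows where every group has at least
# `min_count` values greater than `threshold`.
filter_by_group <- function(dt, group, threshold, min_count) {
  # list of column-name vectors, one element per group
  cols_by_group <- split(unname(group), names(group))
  keep <- Reduce(`&`, lapply(cols_by_group, function(cols) {
    rowSums(dt[, .SD, .SDcols = cols] > threshold) >= min_count
  }))
  dt[keep]
}

# Small hand-made table (assumed data) so the result can be checked by eye
toy <- data.table(sample1 = c(5, 1, 9), sample2 = c(6, 1, 9),
                  sample3 = c(1, 9, 1), sample4 = c(7, 8, 4),
                  sample5 = c(3, 8, 9))
# group_a has three samples, group_b has two
toy_group <- setNames(paste0("sample", 1:5),
                      rep(c("group_a", "group_b"), c(3, 2)))

res <- filter_by_group(toy, toy_group, threshold = 2, min_count = 2)
res  # rows 1 and 3 of toy; row 2 has only one value > 2 in group_a
```

Because the groups come from split(), adding samples or groups only requires changing the named vector.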

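@r2evans suggested an alternative that might perform better when there are many groups. It tests the complement (no group has two or more values <= 2), which matches the condition above because every group here has exactly three samples: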

dt[rowSums(sapply(split(group, names(group)), \(x) rowSums(dt[, .SD, .SDcols = x] <= 2) >= 2 )) == 0, ]
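
For groups of other sizes, one option is to keep the single-matrix shape but count the passing values directly and flag the groups that fall short. This is a hedged sketch rather than what @r2evans proposed; toy, toy_group, and fails are made-up names, and do.call(cbind, ...) is used here instead of sapply so that the result stays a matrix even for a one-row table.

```r
library(data.table)

# Assumed toy data: group "a" has four samples, group "b" has two
toy <- data.table(s1 = c(5, 5, 5), s2 = c(5, 1, 5), s3 = c(1, 1, 5),
                  s4 = c(1, 1, 5), s5 = c(5, 5, 5), s6 = c(5, 5, 1))
toy_group <- setNames(paste0("s", 1:6), rep(c("a", "b"), c(4, 2)))

# One logical column per group: TRUE where the group FAILS the condition
# "at least 2 values greater than 2"
fails <- do.call(cbind, lapply(
  split(unname(toy_group), names(toy_group)),
  function(cols) rowSums(toy[, .SD, .SDcols = cols] > 2) < 2
))

# Keep the rows in which no group fails
res <- toy[rowSums(fails) == 0]
res  # only row 1: row 2 fails in group a, row 3 fails in group b
```

Row 1 has two values <= 2 in group a, so counting values <= 2 would have dropped it even though it has two values greater than 2.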
langtang On

Here is an alternative that is not as slick as @Ritchie Sacramento's solution above, but might be of interest:

Make a long version of the data

# Note: id := .I adds an id column to dt by reference
long_dt <- melt(dt[, id := .I], "id")[data.table(variable = group, group = names(group)), on = "variable"]

Make a helper function

f <- \(d,v,n) d[, sum(value>v),group][, sum(V1>=n)>=.N]

Use the helper function with any threshold v and minimum count n. For example, to keep the rows where each group has at least 2 values exceeding 4:

dt[long_dt[, f(.SD, 4, 2), id][V1 == TRUE, id]]

Output:

   sample1 sample2 sample3 sample4 sample5 sample6    id
     <int>   <int>   <int>   <int>   <int>   <int> <int>
1:      10       9      10       3       7       5     2
2:       6      10       8       5       5       1     3
3:       8       2       9       8       6       6     4
4:       5       4       5      10      10       8     5
5:       9       8       4       7       9       4    10
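
The extra id column in the output appears because dt[, id := .I] modifies dt by reference. If that side effect is unwanted, the same long-format idea can be run on a copy. The following is a sketch under that assumption, with made-up names (toy, toy_group, n_pass, keep_ids) and toy data:

```r
library(data.table)

# Assumed toy data: group "a" has four samples, group "b" has two
toy <- data.table(s1 = c(5, 5, 5), s2 = c(5, 1, 5), s3 = c(1, 1, 5),
                  s4 = c(1, 1, 5), s5 = c(5, 5, 5), s6 = c(5, 5, 1))
toy_group <- setNames(paste0("s", 1:6), rep(c("a", "b"), c(4, 2)))

# Melt a copy so that id := .I does not touch `toy`
long <- melt(copy(toy)[, id := .I], id.vars = "id", variable.factor = FALSE)

# Look up each sample's group from the named vector
long[, grp := names(toy_group)[match(variable, toy_group)]]

# Count values > 2 per row and group, then require every group to reach 2
n_pass <- long[, .(n = sum(value > 2)), by = .(id, grp)]
keep_ids <- n_pass[, .(ok = all(n >= 2)), by = id][ok == TRUE, id]

res <- toy[keep_ids]
res  # only row 1 passes in both groups; toy itself has no id column
```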