
I am looking for a RAM-efficient way to calculate the median over a complement set with the help of data.table.

For a set of observations from different groups, I am interested in an implementation of a median of the "other groups". I.e., if I have a data.table with one value column and one grouping column, I want, for each group, to calculate the median of the values in all other groups except the current group. E.g., for group 1 we calculate the median from all values except those that belong to group 1, and so on.

A concrete example data.table:

dt <- data.table(value = c(1,2,3,4,5), groupId = c(1,1,2,2,2))
dt
#    value groupId
# 1:     1       1
# 2:     2       1
# 3:     3       2
# 4:     4       2
# 5:     5       2

I would like medianOfAllTheOtherGroups to be 4 for group 1 and 1.5 for group 2, repeated for each entry in the same data.table:

dt <- data.table(value = c(1,2,3,4,5), groupId = c(1,1,2,2,2), medianOfAllTheOtherGroups = c(4, 4, 1.5, 1.5, 1.5))

dt
#    value groupId medianOfAllTheOtherGroups
# 1:     1       1                       4.0 # median of all groups _except_ 1
# 2:     2       1                       4.0
# 3:     3       2                       1.5 # median of all groups _except_ 2
# 4:     4       2                       1.5  
# 5:     5       2                       1.5

To calculate the median for each group only once, rather than for each observation, we went for an implementation with a loop. The current implementation works nicely for small data.tables as input, but suffers badly from large RAM consumption on larger data sets, with the median calls inside the loop as the bottleneck (note: for the real use case we have a dt with 3,000,000 rows and 100,000 groups). I have worked very little on improving RAM consumption. Can an expert help to reduce the RAM usage of the minimal example I provide below?

MINIMAL EXAMPLE:

library(data.table)
set.seed(1)
numberOfGroups <- 10
numberOfValuesPerGroup <- 100

# Data table with columns
# groupId - ids of the groups available
# value - the values we want to calculate the median over
dt <-
  data.table(
    groupId = as.character(rep(1:numberOfGroups, each = numberOfValuesPerGroup)),
    value = round(runif(n = numberOfGroups * numberOfValuesPerGroup), 4)
  )

# calculate the median from all observations for those groups that do not 
# require a separate treatment
medianOfAllGroups <-  median(dt$value)
dt$medianOfAllTheOtherGroups <- medianOfAllGroups


# generate extra data.table to collect results for selected groups
includedGroups <-  dt[, unique(groupId)]
dt_otherGroups <-
  data.table(groupId = includedGroups,
             medianOfAllTheOtherGroups = NA_real_)

# loop over all selected groups and calculate the median from all observations
# except those that belong to the current group
for (id in includedGroups){
  dt_otherGroups[groupId == id, 
                 medianOfAllTheOtherGroups := median(dt[groupId != id, value])]
}

# merge subset data to overall data.table
dt[dt_otherGroups, medianOfAllTheOtherGroups := i.medianOfAllTheOtherGroups, 
   on = c("groupId")]

PS: here is the example output for 10 groups with 100 observations each:

dt
#      groupId  value medianOfAllTheOtherGroups
#   1:       1 0.2655                   0.48325
#   2:       1 0.3721                   0.48325
#   3:       1 0.5729                   0.48325
#   4:       1 0.9082                   0.48325
#   5:       1 0.2017                   0.48325
# ---
#  996:      10 0.7768                   0.48590
#  997:      10 0.6359                   0.48590
#  998:      10 0.2821                   0.48590
#  999:      10 0.1913                   0.48590
# 1000:      10 0.2655                   0.48590

Some numbers for different settings of the minimal example (tested on a MacBook Pro with 16 GB RAM):

numberOfGroups  numberOfValuesPerGroup  Memory (GB)  Runtime (s)
           500                      50         0.48         1.47
          5000                      50        39.00        58.00
            50                    5000         0.42         0.65

All memory values were extracted from the output of profvis.

slamballais
  • You may also see discussions on _how_ to measure memory: [data.table vs dplyr memory use revisited](https://stackoverflow.com/questions/61376970/data-table-vs-dplyr-memory-use-revisited); [Memory profiling with data.table](https://stackoverflow.com/questions/58278838/memory-profiling-with-data-table). @jangorecki What's the state of the art? Cheers – Henrik Mar 12 '21 at 20:45

4 Answers


How about this approach:

setkey(dt, groupId)
dt[, median_val := median(dt$value[dt$groupId != groupId]), by = .(groupId)]

For the case of 5000 groups with 50 values each, this took ~34 seconds on my MBP. I haven't checked RAM usage, though.

Edit: here's another version with two changes: (1) using collapse::fmedian as suggested by Henrik, and (2) pre-aggregating the values into a list column by group.

d2 = dt[, .(value = list(value)), keyby = .(groupId)]
setkey(dt, groupId)
dt[, median_val := 
     fmedian(d2[-.GRP, unlist(value, use.names = FALSE, recursive = FALSE)]), 
   by = .(groupId)]  

This took around 18 seconds for the 5000/50 example on my machine.

RAM usage according to profvis: approach 1 ~28 GB, approach 2 ~15 GB.
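If the list column step is unfamiliar, here is a minimal sketch of the layout d2 uses, on hypothetical two-group data (not OP's). Inside the grouped j above, -.GRP plays the role of the -1L below:

library(data.table)

# One row per group; the group's values are stored as a list column
d2 <- data.table(groupId = c("1", "2"),
                 value   = list(c(1, 2), c(3, 4, 5)))

# Dropping the current group's row (here row 1) and unlisting
# returns the values of all *other* groups: 3 4 5
d2[-1L, unlist(value, use.names = FALSE, recursive = FALSE)]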

talat
  • Thank you @talat and @Henrik. Using `collapse::fmedian` is nice as it also allows weighted medians, which I happen to know is of interest for OP :) – Christian Borck Mar 14 '21 at 10:24
  • @talat: Your solution with fmedian looks really cool. However, when implementing it in "practice", I have the problem that I need the response of the function to be a vector ordered in the same order as the input vector (because in fact the whole function is called in a data.table by statement). This means I need a reordering of the data.table (such as Henrik proposed) and it seems this kills everything somehow. And the results are different from before (but on the large final scale - I don't have time to dig into details currently, so I will postpone this to later). – Julia Hillmann Mar 16 '21 at 08:14

Disclaimer: For some reason the profiling keeps crashing my session, so unfortunately I have no such results. However, because my alternatives were a bit faster than OP's, I thought it could still be worth posting them so that OP may assess their memory use.


Data

# numberOfGroups <- 5000
# numberOfValuesPerGroup <- 50
# dt <- ...as in OP...
d1 = copy(dt)
d1[ , ri := .I] # just to be able to restore the original order when comparing results with OP
d2 = copy(dt)
d3 = copy(dt)

Explanation

I shamelessly borrow lines 28, 30-32 from median.default to make a stripped-down version of median.

Calculate total number of rows in the original data (nrow(d1)). Order data by 'value' (setorder). By ordering, two instances of sort in the median code can be removed.

For each 'groupID' (by = groupId):

Calculate length of "other" (number of rows in the original data minus number of rows of current group (.N)).

Calculate median, where the input values are d1$value[-.I], i.e. the original values except the indices of the current group; ?.I: "While grouping, it holds for each item in the group, its row location in x".

Code & Timing

system.time({

  # number of rows in original data    
  nr = nrow(d1)

  # order by value
  setorder(d1, value)
  
  d1[ , med := {
    
    # length of "other"
    n = nr - .N
    
    # ripped from median
    half = (n + 1L) %/% 2L
    if (n %% 2L == 1L) d1$value[-.I][half]
    else mean(d1$value[-.I][half + 0L:1L])
    
  }, by = groupId]
})
  
# user  system elapsed 
# 4.08    0.01    4.07

# OP's code on my (old) PC
#  user  system elapsed 
# 84.02    7.26   86.75 

# restore original order & check equality
setorder(d1, ri)
all.equal(dt$medianOfAllTheOtherGroups, d1$med)
# [1] TRUE 

Comparison with base::median & collapse::fmedian

I also tried the "-.I" approach with base::median and collapse::fmedian, where the latter was about twice as fast as base::median.

system.time(
  d2[ , med := median(d2$value[-.I]), by = groupId]
)
#   user  system elapsed 
#  26.86    0.02   26.85 

library(collapse)
system.time(
  d3[ , med := fmedian(d3$value[-.I]), by = groupId]
)
#   user  system elapsed 
#  16.95    0.00   16.96  

all.equal(dt$medianOfAllTheOtherGroups, d2$med)
# TRUE

all.equal(dt$medianOfAllTheOtherGroups, d3$med)
# TRUE

Thanks a lot to @Cole for helpful comments which improved the performance.

Henrik
  • Nice! Your `d1` approach results in the following timings/memory according to profvis on my machine for numberOfGroups = 5000 and numberOfValuesPerGroup = 50: **5.5 GB, 10,500 ms** – Christian Borck Mar 14 '21 at 10:19
  • Nice answer, fastest of the present alternatives according to my benchmarks and also the most RAM efficient – talat Mar 14 '21 at 10:57
  • @ChristianBorck and talat, thank you for your feedback and for taking the time to run profvis! – Henrik Mar 14 '21 at 10:57
  • Thanks a lot, Henrik, talat and Christian! Your approaches seem to be really nice - I just need a bit of time to see if they can be implemented in the slightly more complex "real" scenario that we face. Thanks a lot so far, I will tell you the outcome as soon as I have tested them in "practice" ;) – Julia Hillmann Mar 14 '21 at 20:28
  • Thanks again! However, I assume that we will need to evaluate fmedian instead of the direct median implementation, as we need a weighted median instead of a normal one (we thought the minimal example would be easier to read with a "normal" median, but we didn't think that so many responses would come in the direction of the median calculation itself - next time I know better ;) ). I have to check in some days, unfortunately we are very busy currently. But I will let you know what worked out best! – Julia Hillmann Mar 16 '21 at 08:08
  • Could you elaborate on how this actually is more RAM efficient? It looks like these all require that all the data is loaded into memory for sorting, no? – ooxio Mar 17 '21 at 10:19

The median is the midpoint of a dataset that's been ordered. For an odd number of values in a dataset, the median is simply the middle number. For an even number of values in a dataset, the median is the mean of the two numbers closest to the middle.

To demonstrate, consider the simple vector of 1:8

1 | 2 | 3 | **4 | 5** | 6 | 7 | 8

In this case, our midpoint is 4.5, and because this is a very simple example, the median itself is also 4.5.
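A quick base R check of that midpoint arithmetic (illustration only):

x <- 1:8
(length(x) + 1) / 2  # midpoint index: 4.5
median(x)            # 4.5, the mean of x[4] and x[5]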

Now consider a grouping where the group is just the first value of the vector, i.e., our group is only 1. We know that this will shift our median towards the right (i.e., larger) because we removed a low value of the distribution. Our new distribution is 2:8 and the median is now 5.

2 | 3 | 4 | *5* | 6 | 7 | 8

This is only interesting if we can determine a relationship between these shifts. Specifically, our original midpoint was 4.5. Our new midpoint based on the original vector is 5.
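The half-step shift is easy to verify in base R (illustration only):

median((1:8)[-1])  # 5: one removal below the old midpoint shifts the index by +0.5
median((1:8)[-8])  # 4: one removal above the old midpoint shifts the index by -0.5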

Let's demonstrate a larger mixture with a group of 1, 3, and 7. In this case, we have 2 values below the original midpoint and one value above the original midpoint. Our new median is 5:

2 | 4 | **5** | 6 | 8

So empirically, we have determined that removing smaller numbers from the distribution shifts our mid_point index by 0.5 and removing larger numbers from the distribution shifts our mid_point index by -0.5. There are a few other stipulations (both are verified in the sketch after the examples):

We need to make sure that our grouping index is not in the new mid_point calculation. Consider a group of 1, 2, and 5. Based on my math, we would shift up by (2 below - 1 above) / 2 = 0.5, for a new mid_point of 4.5 + 0.5 = 5. That's wrong because 5 was already used up! We need to account for this.

3 | 4 | **6** | 7 | 8

Likewise, with our shifted mid_point, we also need to look back to verify that our ranking values are still aligned. In a sequence of 1:20, consider a group of c(1:9, 11). While 11 is above the original mid_point of 10.5, it is not above the shifted mid_point of 10.5 + (9 below - 1 above) / 2 = 14.5. But our actual median would be 15.5, because 11 is now below the new mid_point.

10 | 12 | 13 | 14 | **15 | 16** | 17 | 18 | 19 | 20
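Both stipulations can be sanity-checked against a naive base R computation (the slow path we are trying to avoid, used here for verification only):

median(setdiff(1:8, c(1, 2, 5)))   # 6:    the "used up" midpoint case
median(setdiff(1:20, c(1:9, 11)))  # 15.5: the realignment case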

TL;DR: what's the code?

In all of the examples above, the group's rank positions are given by the special symbol .I, assuming we did setorder(). If we do the same math as above, we don't have to waste time subsetting the dataset; we can instead determine what the new index(es) should be based on what's been removed from the distribution.


setorder(dt, value)

nr = nrow(dt)
is_even = nr %% 2L == 0L
mid_point = (nr + 1L) / 2L

dt[, medianOfAllTheOtherGroups :=
     {
       below = sum(.I < mid_point)
       is_midpoint = is_even && below && (.I[below] + 1L == mid_point)

       above = .N - below - is_midpoint
       new_midpoint = (below - above) / 2L + mid_point
       ## TODO: turn this into a loop in case there are multiple values for which this is true
       if (new_midpoint > mid_point && above && .I[below + 1L] < new_midpoint) { ## check that none of the group's indices were above
         below = below - 1L
         new_midpoint = new_midpoint + 1L
       } else if (new_midpoint < mid_point && below && .I[below] > new_midpoint) {
         below = below + 1L
         new_midpoint = new_midpoint - 1L
       }
       if (((nr - .N + 1L) %% 2L) == 0L) {
         dt$value[new_midpoint]
       } else {
         ## TODO: turn this into a loop in case there are multiple values for which this is true
         default_inds = as.integer(new_midpoint + c(-0.5, 0.5))
         if (below) {
           if (.I[below] == default_inds[1L])
             default_inds[1L] = .I[below] - 1L
         }
         if (above) {
           if (.I[below + 1L + is_midpoint] == default_inds[2L])
             default_inds[2L] = .I[below + 1L] + 1L
         }
         mean(dt$value[default_inds])
       }
     }
   , by = groupId]

Performance

This is using bench::mark, which checks that all results are equal. For Henrik's and my solutions, I reorder the results back to the original order so that they are all equal.

Note that while this (complicated) algorithm is the most efficient, I do want to emphasize that most of these approaches likely do not hit extreme peak RAM usage. The other answers have to subset 5,000 times to allocate a vector of length 249,950 to calculate a new median. That's about 2 MB per loop just on allocation (i.e., ~10 GB overall).
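The rough allocation math behind that claim (an estimate, not a measurement):

n_other <- 5000 * 50 - 50    # 249,950 values in each complement subset
n_other * 8 / 2^20           # ~1.9 MB of doubles allocated per group
5000 * n_other * 8 / 2^30    # ~9.3 GB allocated in total across all groups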

# A tibble: 6 x 13 (only the first five columns shown)
  expression            min   median `itr/sec` mem_alloc
  <bch:expr>       <bch:tm> <bch:tm>     <dbl> <bch:byt>
1 cole              225.7ms  271.8ms    3.68      6.34MB
2 henrik_smart_med    17.7s    17.7s    0.0564   23.29GB
3 henrik_base_med      1.6m     1.6m    0.0104   41.91GB
4 henrik_fmed         55.9s    55.9s    0.0179   32.61GB
5 christian_lookup    54.7s    54.7s    0.0183   51.39GB
6 talat_unlist        35.9s    35.9s    0.0279   19.02GB
Full profile code

library(data.table)
library(collapse)
set.seed(76)
numberOfGroups <- 5000
numberOfValuesPerGroup <- 50

dt <-
  data.table(
    groupId = (rep(1:numberOfGroups, each = numberOfValuesPerGroup)),
    value = round(runif(n = numberOfGroups * numberOfValuesPerGroup, 0, 10), 4)
  )

## this is largely instantaneous.
dt[ , ri := .I]

bench::mark( cole = {
  setorder(dt, value)
  
  nr = nrow(dt)
  is_even = nr %% 2L == 0L
  mid_point = (nr + 1L) / 2L
  
  dt[, medianOfAllTheOtherGroups :=
       {
         below = sum(.I < mid_point)
         is_midpoint = is_even && below && (.I[below] + 1L == mid_point)
         
         above = .N - below - is_midpoint
         new_midpoint = (below - above) / 2L + mid_point
         ## TODO: turn this into a loop in case there are multiple values for which this is true
         if (new_midpoint > mid_point && above && .I[below + 1L] < new_midpoint) { ## check that none of the group's indices were above
           below = below - 1L
           new_midpoint = new_midpoint + 1L
         } else if (new_midpoint < mid_point && below && .I[below] > new_midpoint) {
           below = below + 1L
           new_midpoint = new_midpoint - 1L
         }
         if (((nr - .N + 1L) %% 2L) == 0L) {
           as.numeric(dt$value[new_midpoint])
         } else {
            ## TODO: turn this into a loop in case there are multiple values for which this is true
           default_inds = as.integer(new_midpoint + c(-0.5, 0.5))
           if (below) {
             if (.I[below] == default_inds[1L])
               default_inds[1L] = .I[below] - 1L
           }
           if (above) {
             if (.I[below + 1L + is_midpoint] == default_inds[2L])
               default_inds[2L] = .I[below + 1L] + 1L
           }
           mean(dt$value[default_inds])
         }
       }
     , by = groupId]
  
  setorder(dt, ri)

},
henrik_smart_med = {
  
  # number of rows in original data    
  nr = nrow(dt)
  
  # order by value
  setorder(dt, value)
  
  dt[ , medianOfAllTheOtherGroups := {
    
    # length of "other"
    n = nr - .N
    
    # ripped from median
    half = (n + 1L) %/% 2L
    if (n %% 2L == 1L) dt$value[-.I][half]
    else mean(dt$value[-.I][half + 0L:1L])
    
  }, by = groupId]
  setorder(dt, ri)
},
henrik_base_med = {
  dt[ , med := median(dt$value[-.I]), by = groupId]
},
henrik_fmed = {
  dt[ , med := fmedian(dt$value[-.I]), by = groupId]
}, 
christian_lookup = {
  nrows <- dt[, .N]
  dt_match <- dt[, .(nrows_other = nrows - .N), by = .(groupId_match = groupId)]
  dt_match[, odd := nrows_other %% 2]
  dt_match[, idx1 := ceiling(nrows_other/2)]
  dt_match[, idx2 := ifelse(odd, idx1, idx1+1)]
  
  setkey(dt, value)
  dt_match[, medianOfAllTheOtherGroups := dt[groupId != groupId_match][c(idx1, idx2), sum(value)/2], by = groupId_match]
  dt[dt_match, medianOfAllTheOtherGroups := i.medianOfAllTheOtherGroups, 
     on = c(groupId = "groupId_match")]
},
talat_unlist = {
  d2 = dt[, .(value = list(value)), keyby = .(groupId)]
  setkey(dt, groupId)
  dt[, medianOfAllTheOtherGroups := 
       fmedian(d2[-.GRP, unlist(value, use.names = FALSE, recursive = FALSE)]), 
     by = .(groupId)]  
})
Cole
  • Hi @Cole! In order to learn, I walked through your code on a simple data set: `dt = data.table(value = 1:8, groupId = c(1,1,1,2,2,3,3,3))`. If I have understood OP correctly, the median for group 1 would be calculated as median for all other groups _except_ g1, i.e. median of g2 and g3 = `median(4:8)` = 6; group 2: median of g1 and g3 = `median(c(1:3, 6:8))` = 4.5; group 3: median of g1 and g2 = `median(1:5)` = 3. I don't manage to get those values with your code. Can you please help me understand what is going on. In the large test data, all group lengths were even, here some are odd? Cheers – Henrik Mar 14 '21 at 00:17
  • @Henrik I'm still trying to figure it out myself! I know that I oversimplified it between even vs. odd vectors. But my thought process is that instead of recalculating medians, we only need to see what the new index(es) would be for the new median based on the values that are now being excluded. I'll ping you once I figure it out. Or if I messed up my logic, I'll delete. I do think I will be going to compiled code because to answer your question, I would need to figure out for group 2 that I need to take the average of 3 and 6 to get the median. – Cole Mar 14 '21 at 01:19
  • Thanks for your feedback. I'm anticipating a "positive ping"! Good luck! :) – Henrik Mar 14 '21 at 01:24
  • @Henrik I think this is a largely positive ping. Could you let me know if there are holes in this? I feel like if I had more mathematical background, there would be a simpler method of saying what I want to say. With 5,000 groups, it's equal to your answer for the few seeds I've tried. As for performance, it takes 2.5 seconds on my machine to do 50,000 groups so if it holds up, it should be very performant. – Cole Mar 16 '21 at 00:19
  • Wow! Impressive! Thanks for the ping. I haven't time to look closer at it right now though. – Henrik Mar 16 '21 at 01:47
  • Thanks for your detailed answer Cole! I will check it later! Unfortunately we use a wtd median in reality (we thought for the minimal example it would be easier to read with normal median but we didn't expect to get so many responses that refer to the logic of the median itself), so I need to check in more detail if your solution is extendable (but unfortunately I am very busy these days and didn't expect so many quick reactions ;) ) – Julia Hillmann Mar 16 '21 at 08:05
  • Well... for the question I think this solution is very effective. For a weighted median, whose algorithm I still do not fully understand, I do believe you could do some fancy math on it to apply a similar method. I would probably lean towards Henrik's solution for simplicity. – Cole Mar 17 '21 at 00:42
  • So sorry for the delay. Thank you so much @Henrik. You are too kind and I appreciate it. – Cole May 25 '21 at 23:27

Approach for exact results: the median is "the middle" value of a sorted vector (or the mean of the two middle values for a vector of even length). If we know the length of the sorted vector of the others, we can directly look up the index of the corresponding vector element(s) for the median, thus avoiding actually computing the median for each of the groupIds:

library(data.table)
set.seed(1)
numberOfGroups <- 5000
numberOfValuesPerGroup <- 50

dt <-
  data.table(
    groupId = as.character(rep(1:numberOfGroups, each = numberOfValuesPerGroup)),
    value = round(runif(n = numberOfGroups * numberOfValuesPerGroup), 4)
  )

# group count match table + idx position for median of others
nrows <- dt[, .N]
dt_match <- dt[, .(nrows_other = nrows - .N), by = .(groupId_match = groupId)]
dt_match[, odd := nrows_other %% 2]
dt_match[, idx1 := ceiling(nrows_other/2)]
dt_match[, idx2 := ifelse(odd, idx1, idx1+1)]

setkey(dt, value)
dt_match[, medianOfAllTheOtherGroups := dt[groupId != groupId_match][c(idx1, idx2), sum(value)/2], by = groupId_match]
dt[dt_match, medianOfAllTheOtherGroups := i.medianOfAllTheOtherGroups, 
 on = c(groupId = "groupId_match")]

There might be more data.table-ish ways of improving performance further, I guess.

Memory/runtime for numberOfGroups = 5000 and numberOfValuesPerGroup = 50: 20 GB, 27,000 ms.
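A quick sanity check of the index arithmetic on tiny sorted vectors (base R, illustration only):

v <- sort(runif(7))                     # odd length: a single middle element
idx1 <- ceiling(length(v) / 2)          # 4
idx2 <- idx1                            # odd case: idx2 equals idx1
sum(v[c(idx1, idx2)]) / 2 == median(v)  # TRUE

w <- sort(runif(8))                     # even length: two middle elements
idx1 <- ceiling(length(w) / 2)          # 4
idx2 <- idx1 + 1                        # even case: idx2 = idx1 + 1
all.equal(sum(w[c(idx1, idx2)]) / 2, median(w))  # TRUE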

Christian Borck