December 28, 2021

Running Statistics for Pytorch

Here is a useful little module for computing efficient online GPU statistics in Pytorch.

Pytorch is great for working with small batches of data: if you want to do some calculations over 100 small images, all the features fit into a single GPU and the pytorch functions are perfect.

But what if your data doesn't fit in the GPU all at once? What if it doesn't even fit into CPU RAM? For example, how would you calculate the median values of a set of a few thousand language features over all the tokens in Wikipedia? If the data is small, it's easy: just sort it all and take the middle. But if it doesn't fit - what to do?

import datasets, runningstats
ds = datasets.load_dataset('wikipedia', '20200501.en')['train']
q = runningstats.Quantile()
for batch in runningstats.tally(q, ds, batch_size=100, cache='quantile.npz'):
  feats = compute_features_from_batch(batch)
  q.add(feats) # dim 0 is batch dim; dim 1 is feature dim.
print('median for each feature', q.quantile(0.5))

Here, online algorithms come to the rescue. These are economical algorithms that summarize an endless stream of data using only a small amount of memory. Online algorithms are particularly handy for digesting big data on a GPU where memory is precious. The runningstats module includes running Stat objects for Mean, Variance, Covariance, TopK, Quantile, Bincount, IoU, SecondMoment, CrossCovariance, CrossIoU, as well as an object to accumulate CombinedStats.

They are all single-pass online algorithms that are tested for numerical stability and good performance. They work on both CPU and GPU with all the pytorch data types (whatever device and type of data is passed in, the algorithms will track statistics using the same device and type).

Using an online statistic object is simple. If we just want the variance of a stream of scalars, we feed the stream to a Variance object:

from util.runningstats import Variance
m = Variance()
for x in stream:  # `stream` is a placeholder for your own data source
  m.add(x)
print(f'mean {m.mean()}, variance {m.variance()}')
mean 3.4285714626312256, variance 1.4946550130844116

The variance is always kept up-to-date as more data is added. In the above, the Variance object uses the classic online algorithm by Chan (1979) to compute mean and variance in one pass, rather than requiring a second pass to subtract the mean. Avoiding a second pass is nice if it takes several hours to gather all the data. And it makes the code cleaner.
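
To make the one-pass idea concrete, here is a minimal sketch of a batch-wise mean and variance update using the pairwise merge formula from Chan et al. It is written in plain Python so the arithmetic is visible; the class name RunningVariance and its methods are made up for this illustration and are not the module's actual implementation.

```python
class RunningVariance:
    """One-pass mean/variance accumulator (illustrative sketch only)."""

    def __init__(self):
        self.count = 0    # number of values seen so far
        self.mean = 0.0   # running mean
        self.m2 = 0.0     # running sum of squared deviations from the mean

    def add(self, batch):
        n_b = len(batch)
        if n_b == 0:
            return
        # Summarize the incoming batch on its own.
        mean_b = sum(batch) / n_b
        m2_b = sum((x - mean_b) ** 2 for x in batch)
        # Merge the batch summary into the running summary.
        n = self.count + n_b
        delta = mean_b - self.mean
        self.mean += delta * n_b / n
        self.m2 += m2_b + delta * delta * self.count * n_b / n
        self.count = n

    def variance(self, unbiased=True):
        # Unbiased divides by (n - 1); biased divides by n.
        return self.m2 / (self.count - (1 if unbiased else 0))


rv = RunningVariance()
for batch in ([1, 2, 3], [4, 5], [6, 7, 8, 9, 10]):
    rv.add(batch)
print(rv.mean, rv.variance())
```

Only three numbers are kept no matter how much data streams through, and the mean never has to be known in advance.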

The quantile estimator is based on the modern randomized optimal quantile estimator by Karnin, Lang and Liberty (FOCS 2016) - it's the best online method to get a median, and at the same time, you also get estimators for every other quantile in the distribution.
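
To give a feel for how a streaming quantile sketch can work, here is a deliberately simplified compactor-based sketch in plain Python. It is in the spirit of that family of algorithms but is not the full Karnin-Lang-Liberty method (which, among other things, gives lower levels smaller buffers), and it is not the module's code. The class name CompactorSketch and the capacity parameter are assumptions made for this example; capacity should be even.

```python
import random


class CompactorSketch:
    """Simplified streaming quantile sketch (illustrative, not full KLL)."""

    def __init__(self, capacity=64, seed=0):
        self.capacity = capacity   # max items per level; keep it even
        self.levels = [[]]         # levels[h] holds items of weight 2**h
        self.rng = random.Random(seed)

    def add(self, x):
        self.levels[0].append(x)
        h = 0
        # Compact any level that has filled up, cascading upward.
        while len(self.levels[h]) >= self.capacity:
            self._compact(h)
            h += 1

    def _compact(self, h):
        if h + 1 == len(self.levels):
            self.levels.append([])
        buf = sorted(self.levels[h])
        # Keep every other item (random parity) and promote the survivors,
        # each of which now stands in for two items of the level below.
        offset = self.rng.randrange(2)
        self.levels[h + 1].extend(buf[offset::2])
        self.levels[h] = []

    def quantile(self, q):
        weighted = sorted(
            (x, 2 ** h) for h, level in enumerate(self.levels) for x in level
        )
        if not weighted:
            raise ValueError("no data added yet")
        target = q * sum(w for _, w in weighted)
        running = 0
        for x, w in weighted:
            running += w
            if running >= target:
                return x
        return weighted[-1][0]


sk = CompactorSketch()
for v in range(10000):
    sk.add(v)
print('approximate median', sk.quantile(0.5))
```

Memory grows only with the number of levels, which is logarithmic in the stream length, and the same stored summary answers a query for any quantile, not just the median.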

Because calculating large-scale statistics is often time-consuming, the objects include a cache capability: each stat can be saved and loaded from a file, and there is a tally function that will compute a statistic once but then reload it from a file if it has been computed before. Here's another example:

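from tqdm import tqdm
from util.runningstats import CombinedStat, Mean, Variance, tally

ds = MyDataSet()
# The member stats here (m, v) are an illustrative choice.
cs = CombinedStat(m=Mean(), v=Variance())
# Tally will cache the stats and set up a Dataloader if needed.
for [batch] in tally(cs, ds, cache="savedstats.npz", batch_size=100,
        pin_memory=True, progress=tqdm):
  signals = my_computation(batch.cuda())
  cs.add(signals)
# After savedstats.npz is saved once, the cache is reused.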
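
The compute-once-then-reload pattern can be sketched with a small generator. This is only an illustration of the idea under simplifying assumptions, not the module's tally implementation: the names RunningMean and cached_tally are hypothetical, and the cache here is a JSON file rather than an .npz file.

```python
import json
import os


class RunningMean:
    """Tiny stat with save/load support (illustrative only)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def add(self, batch):
        self.total += sum(batch)
        self.count += len(batch)

    def mean(self):
        return self.total / self.count

    def state_dict(self):
        return {'total': self.total, 'count': self.count}

    def load_state_dict(self, state):
        self.total = state['total']
        self.count = state['count']


def cached_tally(stat, batches, cache=None):
    """Yield batches for the caller to add, unless a cache file exists."""
    if cache is not None and os.path.exists(cache):
        # Cache hit: restore the stat and yield nothing, so the
        # caller's loop body is skipped entirely.
        with open(cache) as f:
            stat.load_state_dict(json.load(f))
        return
    for batch in batches:
        yield batch
    # The loop finished: the stat is complete, so save it for next time.
    if cache is not None:
        with open(cache, 'w') as f:
            json.dump(stat.state_dict(), f)
```

Because the caching lives inside the iterator, the calling code is the same loop whether the statistic is being computed for the first time or restored from disk.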

I have found these classes pretty handy for keeping large-scale statistical estimation code succinct and error-free. Hopefully you will find them useful too.

Posted by David at December 28, 2021 02:23 PM
