December 28, 2021

Running Statistics for Pytorch

Here is runningstats.py, a useful little module for computing efficient online GPU statistics in Pytorch.

Pytorch is great for working with small batches of data: if you want to do some calculations over 100 small images, all the features fit into a single GPU and the pytorch functions are perfect.

But what if your data doesn't fit on the GPU all at once? What if it doesn't even fit in CPU RAM? For example, how would you calculate the median values of a few thousand language features over all Wikipedia tokens? If the data is small, it's easy: just sort the values and take the middle one. But if they don't fit, what do you do?

import datasets, runningstats
ds = datasets.load_dataset('wikipedia', '20200501.en')['train']
q = runningstats.Quantile()
for batch in runningstats.tally(q, ds, batch_size=100, cache='quantile.npz'):
  feats = compute_features_from_batch(batch) # placeholder for your own feature function
  q.add(feats) # dim 0 is batch dim; dim 1 is feature dim.
print('median for each feature', q.quantile(0.5))

Here, online algorithms come to the rescue. These are economical algorithms that summarize an endless stream of data using only a small amount of memory. Online algorithms are particularly handy for digesting big data on a GPU where memory is precious. runningstats.py includes running Stat objects for Mean, Variance, Covariance, TopK, Quantile, Bincount, IoU, SecondMoment, CrossCovariance, CrossIoU, as well as an object to accumulate CombinedStats....
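To make the idea of an online algorithm concrete, here is a minimal pure-Python sketch of a running mean and variance that folds in one batch at a time and keeps only three numbers in memory. It is an illustration only, not the code in runningstats.py: the class name RunningMeanVariance, the batches helper and the synthetic data are all made up for this example, and the update rule is the standard pairwise merge of partial sums (often credited to Chan et al.).

```python
class RunningMeanVariance:
    """Online mean and variance over a stream of batches (illustrative sketch)."""

    def __init__(self):
        self.count = 0   # number of samples seen so far
        self.mean = 0.0  # running mean of all samples seen so far
        self.m2 = 0.0    # running sum of squared deviations from the mean

    def add(self, batch):
        """Fold one batch (a list of floats) into the summary."""
        n = len(batch)
        if n == 0:
            return
        batch_mean = sum(batch) / n
        batch_m2 = sum((x - batch_mean) ** 2 for x in batch)
        delta = batch_mean - self.mean
        total = self.count + n
        # Merge the batch summary into the running summary.
        self.mean += delta * n / total
        self.m2 += batch_m2 + delta * delta * self.count * n / total
        self.count = total

    def variance(self, unbiased=True):
        """Return the variance of everything seen so far."""
        denom = self.count - 1 if unbiased else self.count
        if denom <= 0:
            return float('nan')
        return self.m2 / denom


def batches(n_items, batch_size):
    """Hypothetical data source: yields a synthetic stream one batch at a time."""
    for start in range(0, n_items, batch_size):
        stop = min(start + batch_size, n_items)
        yield [float(i % 7) for i in range(start, stop)]


stat = RunningMeanVariance()
for batch in batches(1000, 100):
    stat.add(batch)  # memory use does not grow with the stream length
print('mean', stat.mean, 'variance', stat.variance())
```

The same pattern (a small fixed-size state plus an add method that merges in each batch) is what lets a statistic be computed over data that never fits in memory at once. Quantiles are harder than means because they cannot be merged exactly from a few numbers, so a streaming quantile has to rely on some form of approximation.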

Posted by David at 02:23 PM