
I've gone through the official docs, but I'm having a hard time understanding what this function is used for and how it works. Can someone explain it in layman's terms?

I get an error when running the example they provide, even though the PyTorch version I'm using matches the documentation. Perhaps fixing the error, which I did, is supposed to teach me something? The snippet given in the documentation is:

   fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
   input = torch.randn(1, 3 * 2 * 2, 1)
   output = fold(input)
   output.size()

and the fixed snippet is:

   fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
   input = torch.randn(1, 3 * 2 * 2, 3 * 2 * 2)
   output = fold(input)
   output.size()
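A minimal sketch of why the fixed snippet works, assuming the default stride=1, padding=0 and dilation=1: the last dimension of the input to nn.Fold must equal the number of sliding-block positions that fit in output_size, which here is (4 - 2 + 1) * (5 - 2 + 1) = 12. That it also equals 3 * 2 * 2 is a coincidence. The variable names below are illustrative.

```python
import torch
from torch import nn

output_size = (4, 5)
kernel_size = (2, 2)
channels = 3

# Number of sliding-block positions with stride=1, padding=0, dilation=1:
# (H - kh + 1) * (W - kw + 1) = 3 * 4
num_blocks = ((output_size[0] - kernel_size[0] + 1)
              * (output_size[1] - kernel_size[1] + 1))
print(num_blocks)  # 12

fold = nn.Fold(output_size=output_size, kernel_size=kernel_size)

# Middle dimension is channels * kh * kw, last dimension is num_blocks
inp = torch.randn(1, channels * kernel_size[0] * kernel_size[1], num_blocks)
out = fold(inp)
print(out.size())  # torch.Size([1, 3, 4, 5])

# With a last dimension of 1 (the original snippet) the block count does not match
try:
    fold(torch.randn(1, channels * kernel_size[0] * kernel_size[1], 1))
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
print(mismatch_rejected)  # True
```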

Thanks!

iacob
shoshi

3 Answers


unfold and fold are used to facilitate "sliding window" operations (like convolutions).
Suppose you want to apply a function foo to every 5x5 window in a feature map/image:

import torch
from torch.nn import functional as f
# x: a batch of feature maps/images with shape (N, C, H, W)
windows = f.unfold(x, kernel_size=5)

Now windows has shape batch-(5*5*x.size(1))-num_windows, and you can apply foo to it:

processed = foo(windows)
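A small sketch to check that shape and layout, assuming an example input of 2 images with 3 channels of size 16x16 (these sizes are arbitrary): each column of windows is one flattened 5x5 window over all channels, and window positions are enumerated row by row.

```python
import torch
from torch.nn import functional as f

x = torch.randn(2, 3, 16, 16)  # assumed example batch: N=2, C=3, 16x16
windows = f.unfold(x, kernel_size=5)

# (N, C*5*5, number of 5x5 positions) = (2, 75, 12*12)
print(windows.shape)  # torch.Size([2, 75, 144])

# Column 0 is the top-left window, column 1 is the window one step to the right
top_left = x[:, :, 0:5, 0:5].reshape(2, -1)
next_right = x[:, :, 0:5, 1:6].reshape(2, -1)
print(torch.equal(windows[:, :, 0], top_left))    # True
print(torch.equal(windows[:, :, 1], next_right))  # True
```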

Now you need to "fold" processed back to the original size of x:

out = f.fold(processed, x.shape[-2:], kernel_size=5)

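You need to take care of padding and kernel_size (and stride, if you use one), since they may affect your ability to fold processed back to the size of x.
Moreover, fold sums over overlapping elements, so you might want to normalize its output. Dividing by the patch size (here 5*5) is only exact for pixels covered by that many windows; pixels near the border are covered by fewer, so dividing by f.fold(f.unfold(torch.ones_like(x), kernel_size=5), x.shape[-2:], kernel_size=5), which counts the windows covering each pixel, is exact wherever the windows reach.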
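A minimal round-trip sketch of the steps above, assuming foo is replaced by the identity so the result can be compared with x: unfold, fold back, then divide by the per-pixel window count. The input size is an arbitrary example.

```python
import torch
from torch.nn import functional as f

x = torch.randn(1, 3, 16, 16)  # assumed example input
k = 5

windows = f.unfold(x, kernel_size=k)  # (1, 3*5*5, 144)
processed = windows                   # identity, standing in for foo
summed = f.fold(processed, x.shape[-2:], kernel_size=k)  # overlaps are summed

# How many windows cover each pixel
counts = f.fold(f.unfold(torch.ones_like(x), kernel_size=k),
                x.shape[-2:], kernel_size=k)

out = summed / counts
print(torch.allclose(out, x))  # True
# Corner pixel is in 1 window, interior pixel in 5*5 = 25 windows
print(counts[0, 0, 0, 0].item(), counts[0, 0, 8, 8].item())  # 1.0 25.0
```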

Shai

One-dimensional unfolding is easy:

import torch
x = torch.arange(1, 9).float()
print(x)
# dimension, size, step
print(x.unfold(0, 2, 1))
print(x.unfold(0, 3, 2))

Out:

tensor([1., 2., 3., 4., 5., 6., 7., 8.])
tensor([[1., 2.],
        [2., 3.],
        [3., 4.],
        [4., 5.],
        [5., 6.],
        [6., 7.],
        [7., 8.]])
tensor([[1., 2., 3.],
        [3., 4., 5.],
        [5., 6., 7.]])
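The number of windows follows from size and step: only complete windows are kept, which is why 8. never appears in the step=2 output. A small sketch; num_windows is a hypothetical helper written for illustration, not a PyTorch function.

```python
import torch

x = torch.arange(1, 9).float()  # 8 elements, as above

def num_windows(length, size, step):
    # Hypothetical helper: count of complete windows; leftover elements are dropped
    return (length - size) // step + 1

for size, step in [(2, 1), (3, 2)]:
    windows = x.unfold(0, size, step)
    print(size, step, windows.shape[0], num_windows(x.size(0), size, step))

# With size=3, step=2 the windows start at indices 0, 2, 4; a fourth would need index 8
last = x.unfold(0, 3, 2)[-1]
print(last)  # tensor([5., 6., 7.])
```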

Two-dimensional unfolding (also called patching):

import torch
patch=(3,3)
x=torch.arange(16).float()
print(x, x.shape)
x2d = x.reshape(1,1,4,4)
print(x2d, x2d.shape)
h,w = patch
c=x2d.size(1)
print(c) # channels
# unfold(dimension, size, step)
r = x2d.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1, c, h, w)
print(r.shape)
print(r) # result

Out:

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15.]) torch.Size([16])
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]]) torch.Size([1, 1, 4, 4])
1
torch.Size([4, 1, 3, 3])

tensor([[[[ 0.,  1.,  2.],
          [ 4.,  5.,  6.],
          [ 8.,  9., 10.]]],


        [[[ 4.,  5.,  6.],
          [ 8.,  9., 10.],
          [12., 13., 14.]]],


        [[[ 1.,  2.,  3.],
          [ 5.,  6.,  7.],
          [ 9., 10., 11.]]],


        [[[ 5.,  6.,  7.],
          [ 9., 10., 11.],
          [13., 14., 15.]]]])
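A sketch of going back from the patches r to the original 4x4 tensor with F.fold, assuming the same stride-1, 3x3 setup. Note that r lists patches with the width position as the outer index, while F.fold expects blocks in row-major order laid out as (N, C*h*w, L), so the patches are reordered first; overlapping values are summed by F.fold and then divided by the coverage count.

```python
import torch
from torch.nn import functional as F

# Rebuild the example: 4x4 input, 3x3 patches, stride 1
x2d = torch.arange(16).float().reshape(1, 1, 4, 4)
h, w = 3, 3
c = x2d.size(1)
r = x2d.unfold(2, h, 1).unfold(3, w, 1).transpose(1, 3).reshape(-1, c, h, w)

n_h = x2d.size(2) - h + 1  # patch positions along the height
n_w = x2d.size(3) - w + 1  # patch positions along the width

# r is ordered (width position, height position); F.fold wants row-major blocks
cols = (r.reshape(n_w, n_h, c * h * w)
         .transpose(0, 1)
         .reshape(n_h * n_w, c * h * w)
         .t()
         .unsqueeze(0))  # (1, C*h*w, L) = (1, 9, 4)

summed = F.fold(cols, output_size=(4, 4), kernel_size=(h, w))
counts = F.fold(torch.ones_like(cols), output_size=(4, 4), kernel_size=(h, w))
restored = summed / counts
print(torch.allclose(restored, x2d))  # True
```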


prosti
  • Can you add the corresponding `.fold` operations to return to the original tensor? – Samuel Mar 01 '21 at 21:05
  • Do you mean can I restore the original tensor back after unfolding? – prosti Mar 02 '21 at 11:27
  • Yes, exactly :) Can you show how it is done? – Samuel Mar 02 '21 at 11:29
  • Check the [fold example](https://programming-review.com/pytorch/tensor#tensor-fold-and-unfold) – prosti Mar 08 '21 at 21:34
  • Wouldn't it be possible to get the same result with a single `F.unfold()` call by doing something like `F.unfold(input=x2d, kernel_size=(3, 3), dilation=(1, 1), stride=(1, 1), padding=(0, 0))`? – Samuel Mar 22 '21 at 18:46
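A sketch comparing the two approaches on the 4x4 example: a single F.unfold call yields the same 3x3 patches, but as columns of a (N, C*h*w, L) tensor in row-major block order, so after reshaping, the patches come out in a different order than r.

```python
import torch
from torch.nn import functional as F

x2d = torch.arange(16).float().reshape(1, 1, 4, 4)
h, w = 3, 3
c = x2d.size(1)

# Chained Tensor.unfold (width position is the outer index of r)
r = x2d.unfold(2, h, 1).unfold(3, w, 1).transpose(1, 3).reshape(-1, c, h, w)

# Single F.unfold call: (N, C*h*w, L) with blocks enumerated row by row
cols = F.unfold(x2d, kernel_size=(h, w), dilation=(1, 1), stride=(1, 1), padding=(0, 0))
patches = cols.transpose(1, 2).reshape(-1, c, h, w)

print(cols.shape)  # torch.Size([1, 9, 4])
# Same patches, different order: r holds (row, col) starts (0,0), (1,0), (0,1), (1,1)
print(torch.equal(patches, r[[0, 2, 1, 3]]))  # True
```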

unfold imagines a tensor as a longer tensor with repeated columns/rows of values "folded" on top of each other, which is then "unfolded":

  • size determines how large the folds are (the length of each window)
  • step determines how often it is folded (the distance between the starts of consecutive windows)

E.g. for a 2x5 tensor, unfolding it along dim=1 with size=2 and step=1:

>>> import torch
>>> x = torch.tensor([[1,2,3,4,5],
...                   [6,7,8,9,10]])
>>> x.unfold(1,2,1)
tensor([[[ 1,  2], [ 2,  3], [ 3,  4], [ 4,  5]],
        [[ 6,  7], [ 7,  8], [ 8,  9], [ 9, 10]]])
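The "repeated" values are not actually copied: Tensor.unfold returns a view of the original tensor. A small sketch showing that the windows share x's storage, so a single write to x appears in both overlapping windows.

```python
import torch

x = torch.tensor([[1, 2, 3, 4, 5],
                  [6, 7, 8, 9, 10]])
windows = x.unfold(1, 2, 1)  # shape (2, 4, 2)

# No data is copied: the windows are a strided view over x's storage
print(windows.data_ptr() == x.data_ptr())  # True
print(x.stride(), windows.stride())        # (5, 1) (5, 1, 1)

# x[0, 2] sits in two overlapping windows, so one write shows up in both
x[0, 2] = 100
print(windows[0, 1], windows[0, 2])
```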


fold (nn.Fold or F.fold, the counterpart of nn.Unfold / F.unfold) is roughly the opposite of this operation, but "overlapping" values are summed in the output rather than kept separately.
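A small sketch of that summing, assuming the 2x5 example is reshaped into the 4-D float (N, C, H, W) input that F.unfold and F.fold expect: 1x2 blocks sliding along the width give the same windows as x.unfold(1, 2, 1), and folding them back doubles every element that belongs to two windows.

```python
import torch
from torch.nn import functional as F

# The 2x5 example as a float (N, C, H, W) batch
x = torch.tensor([[1., 2., 3., 4., 5.],
                  [6., 7., 8., 9., 10.]]).reshape(1, 1, 2, 5)

# 1x2 blocks sliding along the width: 2 rows * 4 positions = 8 blocks
cols = F.unfold(x, kernel_size=(1, 2))  # (1, 2, 8)
summed = F.fold(cols, output_size=(2, 5), kernel_size=(1, 2))

# Interior elements are in two windows, so they come back doubled;
# the first and last element of each row are in only one window
print(summed)
```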

iacob