How do PyTorch’s “Fold” and “Unfold” work?

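unfold takes every slice (window) of a given length along one dimension and lays them out side by side. You can picture the input as a longer tensor whose repeated, overlapping columns/rows of values have been ‘folded’ on top of each other, and unfold “unfolds” them again. The windows are indexed along the chosen dimension, and each window’s values go into a new last dimension: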

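  • size is the length of each window (patch), i.e. the size of the new last dimension
  • step is how far the window moves between consecutive patches: step < size gives overlapping patches, step = size gives non-overlapping ones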
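A small sketch of how size and step interact on a 1-D tensor. The formula for the number of windows, (n - size) // step + 1, follows from the windows starting at 0, step, 2*step, … as long as a full window still fits; trailing elements that cannot fill a whole window are dropped.

```python
import torch

x = torch.arange(1, 8)  # tensor([1, 2, 3, 4, 5, 6, 7])

# Overlapping windows: size=3, step=2 -> windows start at indices 0, 2, 4
overlapping = x.unfold(0, 3, 2)

# Non-overlapping windows: size=2, step=2 -> the trailing 7 does not fit a full window
non_overlapping = x.unfold(0, 2, 2)

# Number of windows along the unfolded dimension
n = x.shape[0]
assert overlapping.shape[0] == (n - 3) // 2 + 1
assert non_overlapping.shape[0] == (n - 2) // 2 + 1
```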

E.g. for a 2×5 tensor, unfolding it along dim=1 with window size=2 and step=1 (the call signature is unfold(dimension, size, step)):

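>>> import torch
>>> x = torch.tensor([[1, 2, 3, 4, 5],
...                   [6, 7, 8, 9, 10]])
>>> x.unfold(1, 2, 1)  # dimension=1, size=2, step=1 (output condensed)
tensor([[[ 1,  2], [ 2,  3], [ 3,  4], [ 4,  5]],
        [[ 6,  7], [ 7,  8], [ 8,  9], [ 9, 10]]])
>>> x.unfold(1, 2, 1).shape  # 4 windows per row, 2 values per window
torch.Size([2, 4, 2])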
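To make the “take every window and stack it” reading concrete, here is a sketch that rebuilds the same result by slicing. The helper unfold_manual is hypothetical (it is not a PyTorch function) and copies data, whereas Tensor.unfold returns a view of the original tensor without copying.

```python
import torch

def unfold_manual(t, dim, size, step):
    """Hypothetical reference helper, for illustration only: collect every
    window of length `size` along `dim`, starting every `step` elements."""
    dim = dim % t.dim()  # normalize negative dims
    n = t.shape[dim]
    # each window keeps all other dims; narrow() slices [start, start + size)
    windows = [t.narrow(dim, start, size) for start in range(0, n - size + 1, step)]
    # stack puts the window index at `dim`, pushing the window length to dim + 1
    stacked = torch.stack(windows, dim=dim)
    # Tensor.unfold places the window length in the last dimension
    return stacked.movedim(dim + 1, -1)

x = torch.tensor([[1, 2, 3, 4, 5],
                  [6, 7, 8, 9, 10]])
manual = unfold_manual(x, 1, 2, 1)
builtin = x.unfold(1, 2, 1)
```

Comparing manual and builtin (values and shape) is a quick way to check your understanding on other shapes, dims and steps.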


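fold is roughly the opposite of this operation: it takes the patches and places each one back at the position it came from. Where patches overlap, the values that land on the same position are summed in the output rather than overwritten or averaged. (In PyTorch the fold/unfold pair lives in torch.nn.Fold / torch.nn.Unfold, or torch.nn.functional.fold / unfold, which work on batched 4-D inputs; Tensor.unfold above has no Tensor.fold counterpart.)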
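A minimal sketch of the summing behaviour, using torch.nn.functional.unfold and fold on a single 1-channel 3×3 image with 2×2 patches and stride 1 (the image values are arbitrary). Folding a tensor of ones the same way counts how many patches covered each pixel, which is a common way to undo the summation.

```python
import torch
import torch.nn.functional as F

# One 1-channel 3x3 image with values 1..9, shape (N, C, H, W) = (1, 1, 3, 3)
img = torch.arange(1., 10.).reshape(1, 1, 3, 3)

# Extract all 2x2 patches: shape (N, C*kH*kW, L) = (1, 4, 4)
patches = F.unfold(img, kernel_size=2, stride=1)

# Put the patches back; pixels covered by several patches get their values summed
summed = F.fold(patches, output_size=(3, 3), kernel_size=2, stride=1)

# How many patches covered each pixel: corners 1, edges 2, center 4
counts = F.fold(F.unfold(torch.ones_like(img), kernel_size=2, stride=1),
                output_size=(3, 3), kernel_size=2, stride=1)

# Dividing by the coverage count recovers the original image
recovered = summed / counts
```

With non-overlapping patches (stride equal to kernel_size, and an image size that the patches tile exactly) every count is 1, and fold exactly inverts unfold.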
