I am trying to create a transform that shuffles the patches of each image in a batch. I aim to use it in the same manner as the rest of the transformations in torchvision:

from torchvision import transforms

trans = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ShufflePatches(patch_size=(16, 16)),  # our new transform
])

More specifically, the input is a BxCxHxW tensor. I want to split each image in the batch into non-overlapping patches of size patch_size, shuffle them, and regroup them into a single image.
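For reference, here is a minimal sketch of the shapes involved when a batch is split into non-overlapping patches with torch.nn.functional.unfold. It assumes a random batch of two 224x224 RGB images (the batch size is arbitrary) and 16x16 patches, with stride equal to the patch size so the patches do not overlap:

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: 2 RGB images of 224x224
x = torch.randn(2, 3, 224, 224)
ph, pw = 16, 16

# stride == kernel_size gives non-overlapping patches
patches = F.unfold(x, kernel_size=(ph, pw), stride=(ph, pw))

# one column per patch: (224 // 16) * (224 // 16) = 196 patches,
# each column holding C * ph * pw = 3 * 16 * 16 = 768 values
print(patches.shape)  # torch.Size([2, 768, 196])
```

Shuffling the patches of one image then amounts to permuting the last dimension of its slice.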

Given the image (of size 224x224):

[image: the 224x224 input image]

Using ShufflePatches(patch_size=(112,112)) I would like to produce the output image:

[image: the desired output, with the four 112x112 patches in shuffled order]

I think the solution involves unfold and fold (torch.nn.functional.unfold / torch.nn.functional.fold), but I didn't manage to get any further.
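One reason fold fits here: when stride equals kernel_size, the patches do not overlap, so each output pixel receives exactly one value and fold exactly undoes unfold. A quick check, assuming H and W are divisible by the patch size (112x112 patches on a random 224x224 batch):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 224, 224)
ps = (112, 112)

# 2 x (3*112*112) x 4: four non-overlapping patches per image
u = F.unfold(x, kernel_size=ps, stride=ps)
# fold sums overlapping contributions; with no overlap it is an exact inverse
y = F.fold(u, x.shape[-2:], kernel_size=ps, stride=ps)
print(torch.equal(x, y))  # True
```

So shuffling the columns of u between these two calls is all that is needed.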

Any help would be appreciated!

Shir

1 Answer


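Indeed, unfold and fold are appropriate here: unfold splits each image into non-overlapping patches (one column per patch), and fold stitches the columns back into an image.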

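import torch
import torch.nn.functional as nnf

class ShufflePatches:
  def __init__(self, patch_size):
    self.ps = patch_size

  def __call__(self, x):
    # inside transforms.Compose, ToTensor yields a single C x H x W image;
    # add a batch dimension so unfold/fold receive a 4-D tensor
    unbatched = x.dim() == 3
    if unbatched:
      x = x.unsqueeze(0)
    # divide the batch of images into non-overlapping patches:
    # u is B x (C*ph*pw) x L, one column per patch
    u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
    # permute the patches (columns) of each image independently
    pu = torch.stack([b_[:, torch.randperm(b_.shape[-1], device=b_.device)] for b_ in u], dim=0)
    # fold the permuted patches back into B x C x H x W images
    f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
    return f.squeeze(0) if unbatched else f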
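If the per-image Python loop over the batch becomes a bottleneck, the permutation step can be done for the whole batch at once. The sketch below is a variant, not the class above: it draws one random permutation per image with torch.argsort over uniform noise and applies it with torch.gather. The function name shuffle_patches is hypothetical, and it assumes a 4-D batch whose H and W are divisible by the patch size.

```python
import torch
import torch.nn.functional as F


def shuffle_patches(x, patch_size):
    """Hypothetical loop-free variant: shuffle the patches of each image in a B x C x H x W batch."""
    # B x (C*ph*pw) x L, one column per patch
    u = F.unfold(x, kernel_size=patch_size, stride=patch_size)
    b, d, l = u.shape
    # argsort of uniform noise gives an independent random permutation per image
    perm = torch.argsort(torch.rand(b, l, device=x.device), dim=1)
    # reorder the patch columns of every image with its own permutation
    pu = torch.gather(u, 2, perm.unsqueeze(1).expand(b, d, l))
    # stitch the shuffled patches back into B x C x H x W images
    return F.fold(pu, x.shape[-2:], kernel_size=patch_size, stride=patch_size)


# Usage: a 224x224 image whose four 112x112 quadrants hold the values 0, 1, 2, 3
x = torch.zeros(1, 1, 224, 224)
x[..., :112, 112:] = 1
x[..., 112:, :112] = 2
x[..., 112:, 112:] = 3
out = shuffle_patches(x, (112, 112))
quadrants = [out[0, 0, r:r + 112, c:c + 112] for r in (0, 112) for c in (0, 112)]
# each quadrant still holds one whole original patch, only their order changes
print(sorted(q[0, 0].item() for q in quadrants))  # [0.0, 1.0, 2.0, 3.0]
```

Because each quadrant of the output is one intact input patch, the sorted values always match the input; only their positions vary between calls.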

Here's an example with patch_size=(16,16):
[image: example output with shuffled 16x16 patches]

Shir
Shai