I am trying to create a transform that shuffles the patches of each image in a batch. I aim to use it in the same manner as the rest of the transformations in torchvision:
trans = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ShufflePatches(patch_size=(16, 16)),  # our new transform
])
More specifically, the input is a BxCxHxW tensor. I want to split each image in the batch into non-overlapping patches of size patch_size, shuffle them, and regroup them into a single image.
Given the image (of size 224x224):
Using ShufflePatches(patch_size=(112,112)), I would like to produce the output image:
I think the solution has to do with torch.unfold and torch.fold, but I didn't manage to get any further.
Any help would be appreciated!
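For reference, here is a minimal sketch of the unfold/fold direction described above. It is only one possible approach, not a confirmed solution. It assumes the input is a floating-point BxCxHxW tensor, that H and W are divisible by the patch size, and that each image in the batch should get its own random permutation. It uses torch.nn.functional.unfold and torch.nn.functional.fold; the gather-based shuffling is my own choice for illustration.

```python
import torch
import torch.nn.functional as F


class ShufflePatches:
    """Shuffle non-overlapping patches of each image in a BxCxHxW batch."""

    def __init__(self, patch_size):
        self.patch_size = patch_size  # (patch_h, patch_w)

    def __call__(self, x):
        b, c, h, w = x.shape
        ph, pw = self.patch_size
        # Assumption: the image is tiled exactly by the patches
        if h % ph or w % pw:
            raise ValueError("H and W must be divisible by the patch size")
        # (B, C*ph*pw, L): one column per patch, L = (H/ph) * (W/pw)
        patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size)
        num_patches = patches.shape[-1]
        # Independent random permutation of the patch columns for every image
        perm = torch.argsort(torch.rand(b, num_patches, device=x.device), dim=1)
        index = perm.unsqueeze(1).expand(-1, patches.shape[1], -1)
        shuffled = torch.gather(patches, 2, index)
        # Place the shuffled patches back on the original H x W grid
        return F.fold(shuffled, output_size=(h, w),
                      kernel_size=self.patch_size, stride=self.patch_size)
```

Because the stride equals the kernel size, the patches do not overlap, so fold simply puts each patch back without summing any pixels. Note that ToTensor yields a single CxHxW image, so inside a Compose pipeline the sketch would need a batch dimension added first, e.g. with x.unsqueeze(0), and removed afterwards.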