You can try these:
Note: The functions below take a 2D tensor as input. If your tensor A is of shape (1, N, N), i.e., has a (redundant) batch/channel dimension, pass A.squeeze() to func().
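To illustrate that note (a minimal sketch with a made-up tensor):

```python
import torch

# A tensor with a redundant leading dimension of size 1
A = torch.zeros(1, 3, 3)

# squeeze() drops all size-1 dimensions, leaving a 2D tensor
print(A.squeeze().shape)  # torch.Size([3, 3])
```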
Method 1:
This method uses broadcasted multiplication followed by transpose and reshape operations to achieve the final result.
import torch
import torch.nn as nn
A = torch.tensor([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
K = 3
def func(A, K):
    ones = torch.ones(K, K)
    # broadcast: one K x K block per element of A -> shape (M*N, K, K)
    # (note: the result is float, since torch.ones() defaults to float32)
    tmp = ones.unsqueeze(0) * A.view(-1, 1, 1)
    tmp = tmp.reshape(A.shape[0], A.shape[1], K, K)  # (M, N, K, K)
    res = tmp.transpose(1, 2).reshape(K * A.shape[0], K * A.shape[1])  # (M*K, N*K)
    return res
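To see why the transpose is needed before the final reshape, here is the same computation traced step by step with concrete shapes (a sketch using a small example input):

```python
import torch

A = torch.tensor([[0, 1], [1, 0]])  # M = N = 2
K = 2
ones = torch.ones(K, K)

# Step 1: one K x K block per element of A
tmp = ones.unsqueeze(0) * A.view(-1, 1, 1)
print(tmp.shape)  # torch.Size([4, 2, 2])

# Step 2: arrange the blocks on the original M x N grid
tmp = tmp.reshape(A.shape[0], A.shape[1], K, K)
print(tmp.shape)  # torch.Size([2, 2, 2, 2])

# Step 3: swap the column-of-blocks and row-within-block axes so that
# flattening interleaves them correctly, then collapse to 2D
res = tmp.transpose(1, 2).reshape(K * A.shape[0], K * A.shape[1])
print(res.shape)  # torch.Size([4, 4])
```

Without the transpose, the final reshape would lay all K rows of one block before moving to the next block, scrambling the output.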
Method 2:
From @Shai's hint in the comments, this method repeats the (2D) tensor K**2 times along the channel dimension and then uses nn.PixelShuffle() to upscale the rows and columns by a factor of K.
def pixelshuffle(A, K):
    pixel_shuffle = nn.PixelShuffle(K)
    # build a 4D input of shape (1, K*K, M, N), shuffle to (1, 1, M*K, N*K),
    # then squeeze back down to a 2D result
    return pixel_shuffle(A.unsqueeze(0).repeat(K**2, 1, 1).unsqueeze(0)).squeeze(0).squeeze(0)
Since nn.PixelShuffle() takes only 4D tensors as input, the unsqueeze() after the repeat() was necessary. Also note that since the tensor returned by nn.PixelShuffle() is also 4D, the two squeeze() calls that follow ensure we get a 2D tensor as output.
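The shape bookkeeping in that one-liner can be unpacked as follows (a sketch with a small example input):

```python
import torch
import torch.nn as nn

A = torch.tensor([[0, 1], [1, 0]])  # M = N = 2
K = 2

x = A.unsqueeze(0)           # (1, M, N)
x = x.repeat(K**2, 1, 1)     # (K*K, M, N): K*K identical channel copies
x = x.unsqueeze(0)           # (1, K*K, M, N): 4D, as nn.PixelShuffle requires
y = nn.PixelShuffle(K)(x)    # (1, 1, M*K, N*K)
out = y.squeeze(0).squeeze(0)  # (M*K, N*K)
print(out.shape)  # torch.Size([4, 4])
```

Because all K**2 channels are identical copies of A, PixelShuffle's channel-to-space rearrangement fills each K x K output block with the same value, which is exactly the desired expansion.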
Some example outputs:
A = torch.tensor([[0, 1], [1, 0]])
func(A, 2)
# tensor([[0., 0., 1., 1.],
# [0., 0., 1., 1.],
# [1., 1., 0., 0.],
# [1., 1., 0., 0.]])
pixelshuffle(A, 2)
# tensor([[0, 0, 1, 1],
# [0, 0, 1, 1],
# [1, 1, 0, 0],
# [1, 1, 0, 0]])
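Note the dtype difference in the two outputs above: func() returns a float tensor (torch.ones() defaults to float32), while pixelshuffle() preserves A's integer dtype. If you need func() to match A's dtype, one option (a sketch, not part of the original answer; func_same_dtype is a hypothetical name) is to create the ones tensor with A's dtype:

```python
import torch

def func_same_dtype(A, K):
    # same as func(), but ones inherits A's dtype so the result does too
    ones = torch.ones(K, K, dtype=A.dtype)
    tmp = ones.unsqueeze(0) * A.view(-1, 1, 1)
    tmp = tmp.reshape(A.shape[0], A.shape[1], K, K)
    return tmp.transpose(1, 2).reshape(K * A.shape[0], K * A.shape[1])

A = torch.tensor([[0, 1], [1, 0]])
print(func_same_dtype(A, 2).dtype)  # torch.int64
```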
Feel free to ask for further clarifications and let me know if it works for you.
Benchmarking:
I benchmarked my answers func() and pixelshuffle() against @iacob's dilate() function above and found that mine are slightly faster.
A = torch.randint(3, 100, (20, 20))
assert (dilate(A, 5) == func(A, 5)).all()
assert (dilate(A, 5) == pixelshuffle(A, 5)).all()
%timeit dilate(A, 5)
# 142 µs ± 2.54 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit func(A, 5)
# 57.9 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit pixelshuffle(A, 5)
# 81.6 µs ± 970 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)