VIT part 1: Patchify Images using PyTorch Unfold

Mriganka Nath
Jan 5, 2024

In this tutorial, we will focus on the first step of the Vision Transformer: converting images into patches. To achieve this, we will use PyTorch’s built-in Unfold operation. We will also explain how Unfold works and demonstrate how it can be used to “patchify” an image.

Figure: the Vision Transformer architecture (https://arxiv.org/abs/2010.11929)

The Unfold function in PyTorch enables access to specific parts of a tensor, allowing for further processing. It extracts blocks from a tensor in a sliding manner. Think of it as similar to the max-pooling or average-pooling operation, where a specified block size is processed, and then the operation slides to the next block. Unfold provides the ability to extract values from these sliding blocks. Additionally, the Unfold operation flattens the values within each block. Let’s illustrate this with an example.

For simplicity, let’s consider a tensor of shape [1 x 1 x 3 x 3] holding the values 1 through 9, where the first dimension is the batch size, the second is the number of channels, and the last two are the height and width. If we define the unfold operation with a kernel size of (2,2), a block of this size slides through the tensor, extracting values, flattening them, and moving to the next position. Initially, it extracts the values [1, 2, 4, 5], then slides to [2, 3, 5, 6], and continues this process by sliding down. The resulting tensor has shape [1 x 4 x 4], where 1 is the batch size, 4 is (kernel height × kernel width × number of channels), and the last dimension is the number of blocks the kernel has slid through. The exact formula for the output shape is provided in the documentation (which I will link at the end). The unfold operation in PyTorch is defined as:

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

The parameters are similar to those we set in pooling/convolution operations in PyTorch.
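To make this concrete, here is a minimal sketch of the 3 x 3 example above, assuming the tensor holds the values 1 through 9 in row-major order:

import torch

# A [1 x 1 x 3 x 3] tensor with values 1..9, unfolded with a 2x2 kernel
# (stride is 1 by default, so the blocks overlap).
x = torch.arange(1.0, 10.0).reshape(1, 1, 3, 3)
unfold = torch.nn.Unfold(kernel_size=(2, 2))

out = unfold(x)
print(out.shape)     # torch.Size([1, 4, 4]) -> [batch, C*kh*kw, no. of blocks]
print(out[0, :, 0])  # tensor([1., 2., 4., 5.]) -> first flattened block
print(out[0, :, 1])  # tensor([2., 3., 5., 6.]) -> block after sliding right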

Now, on to how to use Unfold to make patches from an image. Let’s define a class for that:

import torch
import torch.nn as nn

class Patchify(nn.Module):
    def __init__(self, patch_size=56):
        super().__init__()
        self.p = patch_size
        self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x -> B c h w
        bs, c, h, w = x.shape

        x = self.unfold(x)
        # x -> B (c*p*p) L

        # Reshaping into the shape we want
        a = x.view(bs, c, self.p, self.p, -1).permute(0, 4, 1, 2, 3)
        # a -> ( B no.of patches c p p )
        return a

To obtain non-overlapping patches of the image, we need to ensure that there is no overlap while sliding. Therefore, we set the stride equal to the patch size, so that while ‘unfolding,’ the block slides to a new position without any overlap. Once this operation is complete, we can reshape the tensor as desired using the ‘view’ function. Additionally, we permute the dimensions of the resulting tensor to make it easier to plug into the operations discussed later. In the end, we obtain a tensor of shape: [Batch size x No. of patches x No. of channels x patch height x patch width].
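As a quick sanity check (assuming torch and the Patchify class above are in scope), a patch size of 56 on a 224 x 224 image gives (224 / 56)² = 16 patches:

x = torch.randn(2, 3, 224, 224)       # dummy batch of 2 RGB images
patches = Patchify(patch_size=56)(x)
print(patches.shape)                  # torch.Size([2, 16, 3, 56, 56])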

We can plot it and see the results.

First, let us get an image and use our Patchify class. The image we will be using is of shape 224 x 224 x 3.

import cv2

patch = Patchify()
img_src = image_path                       # path to a 224 x 224 RGB image
image = cv2.imread(img_src)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image.astype('float32') / 255.0    # normalize to [0, 1]
image = torch.from_numpy(image)
image = image.permute(2, 0, 1)             # HWC -> CHW
image = image.unsqueeze(0)                 # add the batch dimension
p = patch(image)
p = p.squeeze(0)                           # remove the batch dimension for plotting

A function to plot the patches in a grid:

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

def plot_patches(tensor):
    # tensor -> (no. of patches, c, p, p); 16 patches fill a 4 x 4 grid
    fig = plt.figure(figsize=(8, 8))
    grid = ImageGrid(fig, 111, nrows_ncols=(4, 4), axes_pad=0.1)

    for i, ax in enumerate(grid):
        patch = tensor[i].permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
        ax.imshow(patch)
        ax.axis('off')

    plt.show()

plot_patches(p)

And finally, we can see the patches.

This marks the first step in our Vision Transformer (VIT) process. Next, we will explore how these patches are fed into our Transformer to generate output results!

In conclusion, similar to NLP transformers, Vision Transformer models use “tokens” created by patchifying images. This involves breaking down the image into patches, projecting them, and employing them as tokens in transformer layers. Patchification is a crucial process as it determines how the Vision Transformer (VIT) perceives the input. Ongoing research aims to enhance this process for improved results.
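As a preview of that projection step, here is a minimal, illustrative sketch; the PatchEmbed name and the 768-dimensional embedding size are assumptions for illustration (borrowed from the base ViT configuration), not part of this post:

class PatchEmbed(nn.Module):
    # Illustrative sketch only: flatten each patch and project it to an
    # embedding ("token"), as described above.
    def __init__(self, patch_size=56, in_channels=3, embed_dim=768):
        super().__init__()
        self.patchify = Patchify(patch_size=patch_size)
        self.proj = nn.Linear(in_channels * patch_size * patch_size, embed_dim)

    def forward(self, x):
        p = self.patchify(x)        # B, no. of patches, c, p, p
        p = p.flatten(start_dim=2)  # B, no. of patches, c*p*p
        return self.proj(p)         # B, no. of patches, embed_dim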

Further Links

A nice video to understand unfold and fold and how to use them for convolution.
