Sprint Chase Technologies
  • Home
  • About
    • Why Choose Us
    • Contact Us
    • Team Members
    • Testimonials
  • Services
    • Web Development
    • Web Application Development
    • Mobile Application Development
    • Web Design
    • UI/UX Design
    • Social Media Marketing
    • Projects
  • Blog
    • PyTorch
    • Python
    • JavaScript
  • IT Institute

Need Help? Talk to an Expert

+91 8000107255

torch.unflatten(): Expand a Dimension of the Input Tensor

PyTorch torch.unflatten() Method
  • Written by krunallathiya21
  • June 10, 2025
PyTorch

The torch.unflatten() method expands a single dimension of a tensor into multiple dimensions. It is the inverse of torch.flatten(): where flatten() merges a range of dimensions into one, unflatten() splits one dimension back into several, restoring the original structure.

Unflattening a PyTorch Tensor

Syntax

torch.unflatten(input, dim, sizes)

Parameters

input (Tensor): The input tensor to unflatten. It must have at least one dimension (1D or higher).

dim (int): The dimension to unflatten. It can be a negative index (e.g., -1 for the last dimension).

sizes (tuple of int or torch.Size): The shape of the new dimensions that replace the specified dimension. The product of sizes must equal the size of the dimension being unflattened, input.shape[dim]. One entry may be -1, in which case PyTorch infers it from the remaining entries.
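The -1 rule and the method form can be checked in a short sketch. It assumes a reasonably recent PyTorch release, where both torch.unflatten() and the equivalent Tensor.unflatten() method are available; the tensor x is just an illustrative input.

```python
import torch

x = torch.arange(12)  # 1D tensor with 12 elements

# Function form: split dim 0 (size 12) into 3 x 4
a = torch.unflatten(x, 0, (3, 4))

# Method form: the same operation called on the tensor itself
b = x.unflatten(0, (3, 4))

# One entry may be -1; PyTorch infers it as 12 // 3 = 4
c = torch.unflatten(x, 0, (3, -1))

print(a.shape, b.shape, c.shape)
# Output: torch.Size([3, 4]) torch.Size([3, 4]) torch.Size([3, 4])
```

All three calls produce the same tensor, so the choice between them is mostly a matter of style.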

Unflattening a 1D tensor to a 2D tensor

Let’s define a 1D tensor of nine elements using the torch.arange() method and unflatten it into a 3×3 2D tensor.

import torch

tensor = torch.arange(9)

print(tensor)
# Output: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

unflattened = torch.unflatten(tensor, dim=0, sizes=(3, 3))

print(unflattened)
# Output:
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])

From the above code, you can see that the 1D tensor of size 9 is unflattened along dimension 0 into a 3×3 tensor. 

The product of sizes (3 * 3 = 9) matches the size of the input tensor.
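For a 1D input, unflattening dim 0 gives the same result as reshape(). A quick check, assuming the same torch.arange(9) input as above:

```python
import torch

tensor = torch.arange(9)

unflattened = torch.unflatten(tensor, dim=0, sizes=(3, 3))
reshaped = tensor.reshape(3, 3)

# Both produce a 3x3 tensor with identical values
print(torch.equal(unflattened, reshaped))
# Output: True
```

The difference shows up with multi-dimensional inputs: unflatten() only asks you to describe the one dimension being split, while reshape() needs the full target shape.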

Negative dimension index

If you are working with tensors whose number of dimensions can vary, a negative index such as -1 lets you unflatten the last dimension without knowing the tensor’s rank in advance.

import torch

# Creating a 2D tensor
tensor = torch.arange(8).reshape(2, 4)

print(tensor)
# Output:
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])

unflattened = torch.unflatten(tensor, dim=-1, sizes=(2, 2))

print(unflattened)
# Output:
# tensor([[[0, 1],
#          [2, 3]],

#         [[4, 5],
#          [6, 7]]])

In this code, the input tensor has shape (2, 4), so its last dimension has size 4.

Passing dim=-1 selects that last dimension, and sizes=(2, 2) splits it into two dimensions of size 2.

The result has shape (2, 2, 2). The product of sizes (2 × 2 = 4) matches the size of the unflattened dimension, and the total number of elements is unchanged: 2 × 2 × 2 = 8, the same as the input’s 2 × 4 = 8.
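Because dim=-1 always targets the last dimension, the same call works for inputs of any rank. In this small sketch, the input shapes are arbitrary examples chosen for illustration:

```python
import torch

# Inputs with different numbers of leading dimensions, all ending in size 4
inputs = [
    torch.arange(4),                    # shape (4,)
    torch.arange(8).reshape(2, 4),      # shape (2, 4)
    torch.arange(24).reshape(2, 3, 4),  # shape (2, 3, 4)
]

results = []
for t in inputs:
    # Split only the last dimension into 2 x 2; leading dimensions are untouched
    out = torch.unflatten(t, dim=-1, sizes=(2, 2))
    results.append(out)
    print(tuple(t.shape), "->", tuple(out.shape))

# Output:
# (4,) -> (2, 2)
# (2, 4) -> (2, 2, 2)
# (2, 3, 4) -> (2, 3, 2, 2)
```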

Unflattening into multiple dimensions

Let’s define a 1D tensor of 24 elements and unflatten it into more than two dimensions, turning the 1D tensor into a 3D tensor.

import torch

# Creating a 1D tensor
tensor = torch.arange(24)

unflattened = torch.unflatten(tensor, dim=0, sizes=(2, 3, 4))

print(unflattened.shape)
# Output: torch.Size([2, 3, 4])

After unflattening, you can verify that the input and output tensors contain the same number of elements: 24 = 2 × 3 × 4.
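unflatten() is not limited to the first or last dimension. A common pattern is splitting a feature dimension in the middle of a batched tensor, for instance into attention heads. The names batch, seq_len, num_heads and head_dim below are illustrative assumptions, not part of any specific model.

```python
import torch

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8

# Hypothetical activations of shape (batch, seq_len, hidden),
# where hidden = num_heads * head_dim = 32
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# Split dim 2 (size 32) into (num_heads, head_dim) = (4, 8)
heads = torch.unflatten(hidden, dim=2, sizes=(num_heads, head_dim))

print(heads.shape)
# Output: torch.Size([2, 5, 4, 8])

# The total number of elements is unchanged
print(hidden.numel() == heads.numel())
# Output: True
```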

Mismatched sizes

Mismatched size causes RuntimeError

What if the product of sizes does not match the size of the dimension? PyTorch raises a RuntimeError.

import torch

tensor = torch.arange(6)
try:
    unflattened = torch.unflatten(tensor, dim=0, sizes=(2, 4))
except RuntimeError as e:
    print(e)

# unflatten: Provided sizes [2, 4] don't multiply up to the size of dim 0 (6) in the input tensor

The error occurs because the requested sizes multiply to 2 × 4 = 8, while dimension 0 of the input holds only 6 elements.

To avoid this RuntimeError, make sure the product of sizes equals the size of the dimension being unflattened (input.shape[dim]), or use -1 for one entry and let PyTorch infer it.
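One way to catch the mismatch before calling unflatten() is to compare the product of sizes with input.shape[dim]. The helper can_unflatten() below is a hypothetical sketch, not a PyTorch API; it also accounts for a single -1 entry, which PyTorch infers.

```python
import math

import torch


def can_unflatten(t: torch.Tensor, dim: int, sizes: tuple) -> bool:
    """Hypothetical helper: True if sizes is compatible with t.shape[dim]."""
    target = t.shape[dim]  # size of the dimension being split (negative dim works too)
    if sizes.count(-1) > 1:
        return False  # at most one entry may be inferred
    if -1 in sizes:
        known = math.prod(s for s in sizes if s != -1)
        # The inferred entry must divide the target size evenly
        return known > 0 and target % known == 0
    return math.prod(sizes) == target


tensor = torch.arange(6)
print(can_unflatten(tensor, 0, (2, 4)))   # Output: False (2 * 4 = 8, not 6)
print(can_unflatten(tensor, 0, (2, 3)))   # Output: True
print(can_unflatten(tensor, 0, (2, -1)))  # Output: True (-1 is inferred as 3)
```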

Reversing a flattened tensor

As mentioned at the start, torch.unflatten() is the inverse of torch.flatten(). In this example, we flatten a 2D tensor and then use unflatten() to restore its original shape.

import torch

# Original tensor
tensor = torch.arange(6).reshape(2, 3)

print(tensor)
# Output:
# tensor([[0, 1, 2],
#         [3, 4, 5]])

flattened = torch.flatten(tensor)

print(flattened)
# Output: tensor([0, 1, 2, 3, 4, 5])

unflattened = torch.unflatten(flattened, dim=0, sizes=(2, 3))

print(unflattened)
# Output: tensor([[0, 1, 2],
#                 [3, 4, 5]])

First, we created a 1D tensor using torch.arange() and reshaped it into a 2 × 3 matrix.

Then, we flattened that matrix into a 1D tensor using torch.flatten().

Finally, we unflattened the 1D tensor back into a 2 × 3 tensor using torch.unflatten().
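The inverse relationship also holds when flatten() merges only part of a tensor. In this sketch, start_dim and end_dim select dimensions 1 and 2 of an illustrative 3D tensor, and unflatten() restores them from the saved original sizes.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# Merge dims 1 and 2 (3 * 4 = 12) into one: shape (2, 12)
flat = torch.flatten(x, start_dim=1, end_dim=2)

# Split dim 1 back using the original sizes of the merged dims
restored = torch.unflatten(flat, dim=1, sizes=x.shape[1:3])

print(flat.shape, restored.shape)
# Output: torch.Size([2, 12]) torch.Size([2, 3, 4])
print(torch.equal(x, restored))
# Output: True
```

Saving the original sizes (here x.shape[1:3]) before flattening is what makes the round trip exact.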


Address:  TwinStar, South Block – 1202, 150 Ft Ring Road, Nr. Nana Mauva Circle, Rajkot(360005), Gujarat, India

sprintchasetechnologies@gmail.com

(+91) 8000107255.


Copyright by @SprintChase  All Rights Reserved
