The torch.unflatten() function splits a single dimension of a tensor into multiple dimensions. It is the inverse of torch.flatten() applied to that dimension, which makes it useful for restoring the original structure of flattened data.

Syntax
torch.unflatten(input, dim, sizes)
Parameters
Argument | Description |
input (Tensor) | The input tensor to unflatten. It must have at least one dimension. |
dim (int) | The dimension to unflatten. It can be a negative index (e.g., -1 for the last dimension). |
sizes (tuple of int or torch.Size) | The shape of the new dimensions that replace dimension dim. The product of sizes must equal the size of that dimension in the input; alternatively, one entry may be -1, in which case PyTorch infers it. |
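The parameters can be checked with a short sketch. It splits the middle dimension of a 3D tensor (the shape (5, 12, 3) is an arbitrary choice) three ways: with explicit sizes, with a -1 entry that PyTorch infers, and with a torch.Size. The -1 inference is documented for recent PyTorch releases; if you are on an older version, confirm it there first.

```python
import torch

# A 3D tensor whose middle dimension (size 12) we want to split.
x = torch.randn(5, 12, 3)

# Explicit sizes: 2 * 6 == 12, so the call is valid.
a = torch.unflatten(x, dim=1, sizes=(2, 6))
print(a.shape)  # torch.Size([5, 2, 6, 3])

# One entry may be -1; PyTorch infers it as 12 // 2 == 6.
b = torch.unflatten(x, dim=1, sizes=(2, -1))
print(b.shape)  # torch.Size([5, 2, 6, 3])

# A torch.Size is accepted for sizes as well: 3 * 4 == 12.
c = torch.unflatten(x, dim=1, sizes=torch.Size([3, 4]))
print(c.shape)  # torch.Size([5, 3, 4, 3])
```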
Unflattening a 1D tensor to a 2D tensor

Let’s create a 1D tensor of nine elements using torch.arange() and unflatten it into a 3×3 2D tensor.
import torch

tensor = torch.arange(9)
print(tensor)
# Output: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

unflattened = torch.unflatten(tensor, dim=0, sizes=(3, 3))
print(unflattened)
# Output:
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])
As the output shows, the 1D tensor of size 9 is unflattened along dimension 0 into a 3×3 tensor.
The call is valid because the product of sizes (3 * 3 = 9) equals the size of dimension 0 of the input.
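Splitting one dimension never requires moving data, and according to the PyTorch documentation torch.unflatten() returns a view of its input. For this example that means it gives the same result as reshape(3, 3) or view(3, 3), and an in-place write through the unflattened tensor shows up in the original. The sketch below checks both points.

```python
import torch

tensor = torch.arange(9)
unflattened = torch.unflatten(tensor, dim=0, sizes=(3, 3))

# Same values and shape as the equivalent reshape/view.
print(torch.equal(unflattened, tensor.reshape(3, 3)))  # True
print(torch.equal(unflattened, tensor.view(3, 3)))     # True

# The result shares storage with the input, so an in-place write
# at [1, 1] (flat position 1 * 3 + 1 = 4) is visible in the original.
unflattened[1, 1] = 100
print(tensor[4].item())  # 100
```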
Negative dimension index
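The dim argument also accepts a negative index, counted from the end. This is useful when the number of leading dimensions varies (for example, inputs with and without a batch dimension) but you always want to unflatten the last one.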
import torch

# Creating a 2D tensor
tensor = torch.arange(8).reshape(2, 4)
print(tensor)
# Output:
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])

unflattened = torch.unflatten(tensor, dim=-1, sizes=(2, 2))
print(unflattened)
# Output:
# tensor([[[0, 1],
#          [2, 3]],
#
#         [[4, 5],
#          [6, 7]]])
In this code, the input tensor has shape (2, 4), so its last dimension has size 4.
Passing dim=-1 selects that last dimension, and sizes=(2, 2) splits it into two dimensions of size 2 each.
This works because 2 x 2 = 4 matches the size of the unflattened dimension. The result has shape (2, 2, 2), so the total number of elements is unchanged: 2 x 2 x 2 = 2 x 4 = 8.
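The benefit of a negative index is that the same call keeps working when leading dimensions come and go. In the sketch below, split_last_dim is a hypothetical helper (not part of PyTorch) that splits the last dimension into (parts, -1); it is applied unchanged to an unbatched (2, 4) tensor and a batched (3, 2, 4) tensor.

```python
import torch

def split_last_dim(x: torch.Tensor, parts: int) -> torch.Tensor:
    """Hypothetical helper: split the last dimension into (parts, -1).

    Because dim=-1 counts from the end, the call does not depend on
    how many leading dimensions x has.
    """
    return torch.unflatten(x, dim=-1, sizes=(parts, -1))

unbatched = torch.arange(8).reshape(2, 4)    # shape (2, 4)
batched = torch.arange(24).reshape(3, 2, 4)  # shape (3, 2, 4)

print(split_last_dim(unbatched, 2).shape)  # torch.Size([2, 2, 2])
print(split_last_dim(batched, 2).shape)    # torch.Size([3, 2, 2, 2])
```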
Unflattening into multiple dimensions
Let’s create a 1D tensor of 24 elements and unflatten it into more than two dimensions, turning a 1D tensor into a 3D tensor.
import torch

# Creating a 1D tensor
tensor = torch.arange(24)

unflattened = torch.unflatten(tensor, dim=0, sizes=(2, 3, 4))
print(unflattened.shape)
# Output: torch.Size([2, 3, 4])
After unflattening, you can verify that the input and output tensors hold the same number of elements: 24 = 2 x 3 x 4.
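To see where each element ends up, note that unflattening follows PyTorch's row-major (C-style) order: in the 2 x 3 x 4 result, index [i, j, k] corresponds to flat position i * 12 + j * 4 + k of the input. Since the input here is torch.arange(24), each value equals its flat position, so the sketch below can check the mapping directly for every index.

```python
import torch

tensor = torch.arange(24)
unflattened = torch.unflatten(tensor, dim=0, sizes=(2, 3, 4))

# Row-major order: index [i, j, k] maps to flat position i*12 + j*4 + k.
for i in range(2):
    for j in range(3):
        for k in range(4):
            assert unflattened[i, j, k].item() == i * 12 + j * 4 + k

print(unflattened[1, 2, 3].item())  # 23 (= 1*12 + 2*4 + 3)

# Unflattening never changes the element count.
print(tensor.numel(), unflattened.numel())  # 24 24
```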
Mismatched sizes

What if the product of sizes does not match the size of the dimension? In that case, PyTorch raises a RuntimeError.
import torch

tensor = torch.arange(6)

try:
    unflattened = torch.unflatten(tensor, dim=0, sizes=(2, 4))
except RuntimeError as e:
    print(e)
# unflatten: Provided sizes [2, 4] don't multiply up to the size of dim 0 (6) in the input tensor
The call fails with "unflatten: Provided sizes [2, 4] don't multiply up to the size of dim 0 (6) in the input tensor" because 2 x 4 = 8, while dimension 0 of the input has only 6 elements.
To avoid this RuntimeError, make sure the product of sizes equals the size of the dimension you are unflattening (input.shape[dim]).
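One way to catch a mismatch early, with a message in your own terms, is to check the product yourself before calling torch.unflatten(). The wrapper below, unflatten_checked, is a hypothetical name used only for illustration; it relies on math.prod (Python 3.8+) and, for simplicity, does not handle a -1 entry in sizes.

```python
import math
import torch

def unflatten_checked(x: torch.Tensor, dim: int, sizes: tuple) -> torch.Tensor:
    """Hypothetical wrapper: validate sizes before calling torch.unflatten()."""
    expected = x.shape[dim]  # size of the dimension being split
    product = math.prod(sizes)
    if product != expected:
        raise ValueError(
            f"sizes {tuple(sizes)} multiply to {product}, "
            f"but dim {dim} has size {expected}"
        )
    return torch.unflatten(x, dim, sizes)

tensor = torch.arange(6)
print(unflatten_checked(tensor, 0, (2, 3)).shape)  # torch.Size([2, 3])

try:
    unflatten_checked(tensor, 0, (2, 4))
except ValueError as e:
    print(e)  # sizes (2, 4) multiply to 8, but dim 0 has size 6
```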
Reversing a flattened tensor
As mentioned in the introduction, torch.unflatten() reverses torch.flatten(). In this example, we flatten a 2D tensor and then restore its original shape with torch.unflatten().
import torch

# Original tensor
tensor = torch.arange(6).reshape(2, 3)
print(tensor)
# tensor([[0, 1, 2],
#         [3, 4, 5]])

flattened = torch.flatten(tensor)
print(flattened)
# Output: tensor([0, 1, 2, 3, 4, 5])

unflattened = torch.unflatten(flattened, dim=0, sizes=(2, 3))
print(unflattened)
# Output:
# tensor([[0, 1, 2],
#         [3, 4, 5]])
First, we created a 1D tensor with torch.arange() and reshaped it into a 2 x 3 matrix.
Then, we flattened that matrix into a 1D tensor with torch.flatten().
Finally, we unflattened that 1D tensor back into a 2 x 3 tensor with torch.unflatten(), recovering the original.
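torch.flatten() can also merge only a range of dimensions through its start_dim and end_dim arguments. The inverse is then to unflatten the merged dimension using the original sizes of the dimensions that were merged, which can be sliced straight from the original tensor's shape. The shape (2, 3, 4, 5) below is an arbitrary choice for illustration.

```python
import torch

x = torch.randn(2, 3, 4, 5)

# Merge dimensions 1 and 2 (sizes 3 and 4) into one dimension of size 12.
flat = torch.flatten(x, start_dim=1, end_dim=2)
print(flat.shape)  # torch.Size([2, 12, 5])

# Undo it: split dimension 1 back into x.shape[1:3], i.e. torch.Size([3, 4]).
restored = torch.unflatten(flat, dim=1, sizes=x.shape[1:3])
print(restored.shape)            # torch.Size([2, 3, 4, 5])
print(torch.equal(restored, x))  # True
```

Saving the original shape before flattening is what makes the round trip exact; unflatten itself has no memory of how the dimension was formed.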