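The torch.unbind() function removes a specified dimension from an input tensor and returns a tuple of tensors, one for each index along that dimension. In other words, it splits the tensor into individual slices along a given axis. Unlike torch.split() or torch.chunk(), whose pieces keep the split dimension, unbind() drops that dimension entirely, so each slice has one fewer dimension than the input.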
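The difference from split() and chunk() is easiest to see in the shapes they return. This is a small sketch; the 2×3 tensor `t` and the variable names are only for illustration:

```python
import torch

# A 2x3 tensor used only for illustration
t = torch.arange(6).reshape(2, 3)

by_split = torch.split(t, 1, dim=0)   # pieces of size 1 along dim 0
by_chunk = torch.chunk(t, 2, dim=0)   # 2 chunks along dim 0
by_unbind = torch.unbind(t, dim=0)    # dim 0 is removed

print([p.shape for p in by_split])    # [torch.Size([1, 3]), torch.Size([1, 3])]
print([p.shape for p in by_chunk])    # [torch.Size([1, 3]), torch.Size([1, 3])]
print([p.shape for p in by_unbind])   # [torch.Size([3]), torch.Size([3])]

# Each unbind slice equals the matching split piece with dim 0 squeezed out
for s, u in zip(by_split, by_unbind):
    assert torch.equal(s.squeeze(0), u)
```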

Syntax
torch.unbind(input, dim=0)
Parameters
| Argument | Description |
|---|---|
| input (Tensor) | The tensor to unbind. |
| dim (int, optional) | The dimension to remove. Defaults to 0. |
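A quick way to check these parameters is to look at what comes back. In this sketch the tensor `x` is arbitrary; `dim` is first omitted so it defaults to 0, and then the same operation is called through the `Tensor.unbind()` method form:

```python
import torch

x = torch.zeros(4, 5, 6)

parts = torch.unbind(x)          # dim defaults to 0
print(type(parts), len(parts))   # <class 'tuple'> 4
print(parts[0].shape)            # torch.Size([5, 6])

# The method form on a tensor accepts the same dim argument
parts_last = x.unbind(dim=2)
print(len(parts_last), parts_last[0].shape)  # 6 torch.Size([4, 5])
```

The tuple always has as many elements as the size of `x` along `dim`.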
Unbinding a 2D tensor along rows (dim=0)
If you pass dim=0, the tensor is split along its first dimension, which for a 2D tensor means into its rows. A 2D tensor is also called a matrix, so here we are dividing a matrix into its rows.
```python
import torch

# Creating a 2D tensor (2x2)
tensor = torch.tensor([[11, 19], [21, 18]])
print(tensor)
# Output: tensor([[11, 19],
#                 [21, 18]])

rows = torch.unbind(tensor, dim=0)
print(rows)
# Output: (tensor([11, 19]), tensor([21, 18]))
# A tuple of two 1D tensors representing the rows of the original tensor.
```
In the above code, we define a 2×2 matrix, and the output is a tuple of two 1D tensors, which represent the rows of the matrix.
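One detail worth knowing: the returned rows are views into the original tensor's memory, not copies (PyTorch's documentation lists unbind() among its view operations). The sketch below reuses the 2×2 values from the example to show that writing into a row changes the original, and that `.clone()` gives independent copies when that is not wanted:

```python
import torch

tensor = torch.tensor([[11, 19], [21, 18]])
rows = torch.unbind(tensor, dim=0)

# rows[0] shares storage with tensor, so this write is visible in tensor
rows[0][0] = 100
print(tensor[0, 0].item())  # 100

# Clone the slices if you need copies that don't affect the original
copies = [row.clone() for row in torch.unbind(tensor, dim=0)]
copies[1][1] = 0
print(tensor[1, 1].item())  # 18
```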
Unbinding a 2D tensor along columns (dim=1)

For unbinding along columns, we need to pass the dim=1 argument. It will return a tuple of 1D tensors that represent columns of the input matrix.
```python
import torch

# Creating a 2D tensor (2x2)
tensor = torch.tensor([[11, 19], [21, 18]])
print(tensor)
# Output: tensor([[11, 19],
#                 [21, 18]])

columns = torch.unbind(tensor, dim=1)
print(columns)
# Output: (tensor([11, 21]), tensor([19, 18]))
# A tuple of two 1D tensors representing the columns of the original tensor.
```
Since the input tensor is 2×2, there are two columns, and each column tensor has two elements, one per row.
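Because a 2×2 input makes the row and column counts identical, it can hide which size controls what. A sketch with a rectangular tensor (the 3×4 size is arbitrary) makes the rule visible: unbinding an (m, n) tensor along dim=1 returns n tensors, each of shape (m,), and the j-th one matches the ordinary slice `t[:, j]`:

```python
import torch

m, n = 3, 4
t = torch.arange(m * n).reshape(m, n)  # values 0..11 laid out row by row

columns = torch.unbind(t, dim=1)
print(len(columns))   # 4
print(columns[0])     # tensor([0, 4, 8])

# Every column has one element per row and matches ordinary column slicing
for j, col in enumerate(columns):
    assert col.shape == (m,)
    assert torch.equal(col, t[:, j])
```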
Unbinding a 3D Tensor
If you are working with a 3D tensor, you can also unbind it along a specific dimension.

```python
import torch

# Creating a 3D tensor
tensor = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # Shape: (2, 2, 2)
print(tensor)
# Output: tensor([[[1, 2],
#                  [3, 4]],
#
#                 [[5, 6],
#                  [7, 8]]])

# Unbinding the tensor along dimension 0
slices = torch.unbind(tensor, dim=0)
print(slices)
# Output: (tensor([[1, 2],
#                  [3, 4]]),
#          tensor([[5, 6],
#                  [7, 8]]))
```
The program above defines a 3D tensor of shape (2, 2, 2) and unbinds it along dimension 0, producing a tuple of two 2D tensors, each of shape (2, 2).
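A 3D tensor can be unbound along any of its dimensions. This sketch uses a (2, 3, 4) tensor, chosen so that each dimension has a different size, and prints how many slices each choice of `dim` produces and what shape they have:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

for d in range(x.dim()):
    slices = torch.unbind(x, dim=d)
    print(d, len(slices), tuple(slices[0].shape))
# 0 2 (3, 4)
# 1 3 (2, 4)
# 2 4 (2, 3)

# Unbinding along dim=1 is equivalent to indexing x[:, i, :] for each i
for i, s in enumerate(torch.unbind(x, dim=1)):
    assert torch.equal(s, x[:, i, :])
```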
Negative dimension index

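A negative dim counts backward from the last dimension, so for a 2D tensor dim=-1 refers to the last dimension (the same as dim=1) and the result is a tuple of columns.

```python
import torch

# Creating a 2x3 tensor
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # Shape: (2, 3)
print(tensor)
# Output:
# tensor([[1, 2, 3],
#         [4, 5, 6]])

# Unbinding the tensor along dimension -1
result = torch.unbind(tensor, dim=-1)
print(result)
# Output: (tensor([1, 4]), tensor([2, 5]), tensor([3, 6]))
```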
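As a quick sanity check, here is a sketch comparing a negative `dim` with its positive equivalent; the 3D tensor `y` is just an extra illustration:

```python
import torch

tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
neg = torch.unbind(tensor, dim=-1)
pos = torch.unbind(tensor, dim=1)
assert all(torch.equal(a, b) for a, b in zip(neg, pos))

# For a 3D tensor, dim=-2 refers to the middle dimension (index 1)
y = torch.zeros(2, 5, 7)
print(len(torch.unbind(y, dim=-2)))  # 5
```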
Comparison with torch.stack()
torch.unbind() is the reverse operation of torch.stack(). Let me demonstrate it.
```python
import torch

# Stack tensors
t1 = torch.tensor([1, 2])
t2 = torch.tensor([3, 4])
stacked = torch.stack([t1, t2], dim=0)  # Shape: (2, 2)
print(stacked)
# Output: tensor([[1, 2],
#                 [3, 4]])

unbound = torch.unbind(stacked, dim=0)
print(unbound)
# Output: (tensor([1, 2]), tensor([3, 4]))
```
In the above code, torch.stack() combines the tensors along a new dimension, and torch.unbind() reverses this by removing that same dimension, giving back the original tensors.
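The round trip is not limited to dim=0. As long as stack() is given the same `dim` that unbind() removed, the original tensor comes back. This sketch checks every dimension of an arbitrary (2, 3, 4) tensor, then shows what happens when the two dims do not match:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

for d in range(x.dim()):
    pieces = torch.unbind(x, dim=d)
    rebuilt = torch.stack(pieces, dim=d)
    assert torch.equal(rebuilt, x)

# Stacking along a different dim gives a rearranged tensor instead
mixed = torch.stack(torch.unbind(x, dim=0), dim=2)
print(mixed.shape)  # torch.Size([3, 4, 2])
```

Here `mixed` holds the same values as `x.permute(1, 2, 0)`, so a mismatched `dim` silently moves data around rather than raising an error.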
That’s all!