The torch.flip() function reverses the order of elements in a tensor along the specified dimensions and returns a new tensor. It can flip along a single dimension or several dimensions at once.
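Per the PyTorch documentation, torch.flip() copies the input's data. NumPy's np.flip() behaves differently and returns a view. So writing into the flipped result should leave the original tensor untouched. Below is a quick check, assuming a recent PyTorch release:

```python
import torch

original = torch.tensor([1, 2, 3])
flipped = torch.flip(original, dims=[0])

# Modify the flipped tensor in place
flipped[0] = 100

print(flipped)   # tensor([100,   2,   1])
print(original)  # tensor([1, 2, 3]), unchanged because flip made a copy
```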
Syntax
torch.flip(input, dims)
Parameters
| Argument | Description |
|---|---|
| input (Tensor) | The input tensor to flip. |
| dims (list or tuple of ints) | The dimension(s) along which to flip the tensor. Pass a list or tuple even when flipping a single dimension, e.g. dims=[0]. Negative indices are supported (e.g., -1 refers to the last dimension). |
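As a small sanity check of the parameters, the calls below should all produce the same result. They pass dims as a list, as a tuple, through the Tensor.flip() method form, and as the equivalent negative index:

```python
import torch

t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

a = torch.flip(t, dims=[1])    # dims as a list
b = torch.flip(t, dims=(1,))   # dims as a tuple
c = t.flip([1])                # Tensor.flip method form
d = torch.flip(t, dims=[-1])   # -1 is the last dimension, i.e. dim 1 here

print(torch.equal(a, b), torch.equal(a, c), torch.equal(a, d))  # True True True
```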
Flipping a 1D Tensor
Let’s reverse the order of the elements in a 1D tensor.
import torch
tensor = torch.tensor([1, 12, 3, 41, 5])
print(tensor)
# Output: tensor([ 1, 12,  3, 41,  5])
flipped = torch.flip(tensor, dims=[0])
print(flipped)
# Output: tensor([ 5, 41,  3, 12,  1])
The output contains the same elements as the input, in reverse order: the last element comes first and the first element comes last.
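To see what the 1D flip does element by element, you can rebuild the same result by gathering with reversed indices. This sketch is only an illustration. PyTorch does not support negative-step slicing such as tensor[::-1], which is one reason torch.flip() is the usual way to reverse a tensor.

```python
import torch

tensor = torch.tensor([1, 12, 3, 41, 5])

# Build the indices n-1, n-2, ..., 0
reversed_idx = torch.arange(tensor.size(0) - 1, -1, -1)
manual = tensor.index_select(0, reversed_idx)

print(reversed_idx)  # tensor([4, 3, 2, 1, 0])
print(torch.equal(manual, torch.flip(tensor, dims=[0])))  # True
```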
Flipping a 2D Tensor along rows and columns
A 2D tensor can be flipped along its rows, its columns, or both, and you choose which through the dims argument.
To reverse the order of the rows, pass dims=[0]. To reverse the order of the elements within each row (that is, the column order), pass dims=[1].
import torch
tensor = torch.tensor([[1, 2, 3],
                       [14, 15, 16]])
flipped_rows = torch.flip(tensor, dims=[0])
print(flipped_rows)
# Output: tensor([[14, 15, 16],
#                 [ 1,  2,  3]])
flipped_cols = torch.flip(tensor, dims=[1])
print(flipped_cols)
# Output: tensor([[ 3,  2,  1],
#                 [16, 15, 14]])
As the output shows, dims=[0] flips the rows (a vertical flip) and swaps row 0 with row 1. In contrast, dims=[1] flips the columns (a horizontal flip) and reverses the elements within each row.
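Recent PyTorch versions also provide torch.flipud() and torch.fliplr() as shorthands for these two 2D cases. The check below compares them against the equivalent torch.flip() calls:

```python
import torch

tensor = torch.tensor([[1, 2, 3],
                       [14, 15, 16]])

# flipud reverses dim 0 (row order); fliplr reverses dim 1 (order within each row)
print(torch.equal(torch.flipud(tensor), torch.flip(tensor, dims=[0])))  # True
print(torch.equal(torch.fliplr(tensor), torch.flip(tensor, dims=[1])))  # True
```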
Flipping along multiple dimensions
We can also pass both dimensions together. For a 2D tensor, this is equivalent to rotating it by 180 degrees.
import torch
tensor = torch.tensor([[11, 21, 31],
                       [14, 15, 16]])
print(tensor)
# Output: tensor([[11, 21, 31],
#                 [14, 15, 16]])
flipped_both = torch.flip(tensor, dims=[0, 1])
print(flipped_both)
# Output: tensor([[16, 15, 14],
#                 [31, 21, 11]])
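Passing dims=[0, 1] reverses both the row order and the order of elements within each row. The two reversals are independent, so the result is the same whichever order the dimensions are listed in.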
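The sketch below checks two properties of the combined flip. First, the order of entries in dims does not matter. Second, the result matches a 180-degree rotation computed with torch.rot90(), with k=2 in the plane of dimensions 0 and 1.

```python
import torch

tensor = torch.tensor([[11, 21, 31],
                       [14, 15, 16]])

both = torch.flip(tensor, dims=[0, 1])

# Listing the dimensions in the other order gives the same result
print(torch.equal(both, torch.flip(tensor, dims=[1, 0])))  # True

# Two 90-degree rotations in the (0, 1) plane give a 180-degree rotation
print(torch.equal(both, torch.rot90(tensor, 2, [0, 1])))  # True
```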
Flipping a 3D Tensor
A 3D tensor can also be flipped along one or more of its dimensions.
import torch
tensor = torch.tensor([[[1, 2],
                        [3, 4]],
                       [[5, 6],
                        [7, 8]]])
print(tensor)
# Output:
# tensor([[[1, 2],
#          [3, 4]],
#
#         [[5, 6],
#          [7, 8]]])
flipped_3d = torch.flip(tensor, dims=[2])
print(flipped_3d)
# Output:
# tensor([[[2, 1],
#          [4, 3]],
#
#         [[6, 5],
#          [8, 7]]])
Flipping along dims=[2] reverses the innermost dimension, so the two elements in each row of every 2×2 matrix swap places. The order of the matrices, and of the rows within each matrix, is unchanged.
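For comparison, here is the same 3D tensor flipped along the other two dimensions. Flipping along dims=[0] reorders the whole 2×2 matrices. Flipping along dims=[1] reverses the rows inside each matrix.

```python
import torch

tensor = torch.tensor([[[1, 2],
                        [3, 4]],
                       [[5, 6],
                        [7, 8]]])

# dims=[0] swaps the two 2x2 matrices but leaves each one intact
outer = torch.flip(tensor, dims=[0])
print(outer.tolist())   # [[[5, 6], [7, 8]], [[1, 2], [3, 4]]]

# dims=[1] reverses the rows inside each matrix
middle = torch.flip(tensor, dims=[1])
print(middle.tolist())  # [[[3, 4], [1, 2]], [[7, 8], [5, 6]]]
```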
Negative Indices
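Negative values in dims count backward from the last dimension. For a tensor with n dimensions, dimension -k is the same as dimension n - k. A 2D tensor with shape (2, 3) has n = 2, so -1 corresponds to dim 1 (the columns), and flipping along it reverses the elements in each row.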
import torch
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])
flipped = torch.flip(tensor, dims=[-1])
print(flipped)
# Output: tensor([[3, 2, 1],
#                 [6, 5, 4]])
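The same rule works in higher dimensions. The loop below uses a 3D tensor built with torch.arange() (n = 3). For each negative index, it checks that flipping gives the same result as flipping along the corresponding non-negative index, so -1 maps to 2, -2 to 1, and -3 to 0.

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)  # shape (2, 3, 4), so n = 3

for neg in (-1, -2, -3):
    pos = neg + t.dim()  # e.g. -1 + 3 = 2
    same = torch.equal(torch.flip(t, dims=[neg]), torch.flip(t, dims=[pos]))
    print(neg, "->", pos, same)
# -1 -> 2 True
# -2 -> 1 True
# -3 -> 0 True
```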
Invalid negative index
A 1D tensor has only one dimension, so the only valid values in dims are 0 and -1. If you pass -2, as though the tensor had a second (column) dimension, torch.flip() raises an IndexError.
import torch
tensor = torch.tensor([1, 2, 3])
try:
    flipped = torch.flip(tensor, dims=[-2])
except IndexError as e:
    # print(e) shows only the message, without the "IndexError:" prefix
    print(e)
# Output: Dimension out of range (expected to be in range of [-1, 0], but got -2)
If the exception is not caught, it is reported as: IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)
To fix this error, pass dims=[0] or dims=[-1] instead of -2, since a 1D tensor has only one dimension.
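If dims comes from user input or configuration, you could validate it against tensor.dim() before calling torch.flip(). The checked_flip helper below is a hypothetical convenience wrapper, not part of PyTorch. It applies the same rule as the error message: for an n-dimensional tensor, valid values lie in [-n, n - 1].

```python
import torch

def checked_flip(t, dims):
    """Hypothetical wrapper (not a PyTorch API): validate dims, then call torch.flip."""
    n = t.dim()
    for d in dims:
        if not -n <= d <= n - 1:
            raise ValueError(f"dim {d} is invalid for a {n}D tensor; use a value in [{-n}, {n - 1}]")
    return torch.flip(t, dims=dims)

tensor = torch.tensor([1, 2, 3])
print(checked_flip(tensor, [-1]))  # tensor([3, 2, 1])

try:
    checked_flip(tensor, [-2])
except ValueError as e:
    print(e)  # dim -2 is invalid for a 1D tensor; use a value in [-1, 0]
```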
Repeating dimensions
If the same dimension appears more than once in dims, torch.flip() raises a RuntimeError: dim 0 appears multiple times in the list of dims.
import torch
tensor = torch.tensor([1, 2, 3])
try:
    flipped = torch.flip(tensor, dims=[0, 0])
except RuntimeError as e:
    print(e)
# Output: dim 0 appears multiple times in the list of dims
To prevent this type of error, ensure that you don’t repeat the dimension while flipping a tensor.
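Negative indices are resolved before the duplicate check, so two entries that look different can still collide. On a 1D tensor, 0 and -1 name the same dimension. The sketch below shows that case and one possible way to deduplicate first: map each entry to its non-negative form with a modulo. This assumes every entry is already within the valid range.

```python
import torch

tensor = torch.tensor([1, 2, 3])
try:
    torch.flip(tensor, dims=[0, -1])  # -1 resolves to 0 for a 1D tensor
except RuntimeError as e:
    print(e)  # dim 0 appears multiple times in the list of dims

# Normalize each (in-range) entry to non-negative form, then drop duplicates
dims = [0, -1]
unique_dims = sorted({d % tensor.dim() for d in dims})
print(unique_dims)                           # [0]
print(torch.flip(tensor, dims=unique_dims))  # tensor([3, 2, 1])
```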
That’s all!
