The torch.roll() function performs a circular shift of a tensor's elements along the specified dimensions by a given number of positions. Elements pushed past the last position wrap around to the first.

Syntax
torch.roll(input, shifts, dims=None)
Parameters
Argument | Description |
input (Tensor) | The input tensor whose elements will be shifted. |
shifts (int or tuple of ints) | The number of positions to shift elements. Positive values shift right/down. Negative values shift left/up. If shifts is a tuple, dims must be a tuple of the same size. |
dims (int or tuple of ints) | The dimension(s) along which to roll. If None, the tensor is flattened, rolled, and reshaped back. |
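The flatten, roll, and reshape behaviour described for dims=None can be checked with a small sketch. The 2x3 tensor and the variable names here are chosen only for illustration.

```python
import torch

# A small 2x3 tensor chosen for illustration.
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# No dims given: the tensor is flattened, rolled, then reshaped back to 2x3.
rolled_flat = torch.roll(x, shifts=1)

# The same steps written out by hand.
manual = torch.roll(x.flatten(), shifts=1, dims=0).reshape(x.shape)

print(rolled_flat.tolist())              # [[6, 1, 2], [3, 4, 5]]
print(torch.equal(rolled_flat, manual))  # True
```

Note that the last element of the second row wraps to the start of the first row, which would not happen when rolling along a single dimension.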
Rolling a 1D Tensor

Let’s define a 1D tensor and shift its elements by one position.
import torch

tensor = torch.tensor([1, 21, 31, 4, 5])
print(tensor)
# Output: tensor([ 1, 21, 31,  4,  5])

rolled = torch.roll(tensor, shifts=1)
print(rolled)
# Output: tensor([ 5,  1, 21, 31,  4])
In the above program, you can see that a circular shift happened: element 1 is now in the second position, and element 5, which was last, has wrapped around to the first position.
If we want to shift elements by two positions, we need to pass shifts=2 as an argument.
import torch

tensor = torch.tensor([1, 21, 31, 4, 5])
print(tensor)
# Output: tensor([ 1, 21, 31,  4,  5])

rolled = torch.roll(tensor, shifts=2)
print(rolled)
# Output: tensor([ 4,  5,  1, 21, 31])
# Shifted by 2 positions to the right
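One way to reason about these 1D results is as an index mapping: the element at index i ends up at index (i + shift) modulo the length. The sketch below compares that rule against torch.roll. The helper roll_1d_reference is hypothetical and written only for this illustration; it is not part of PyTorch.

```python
import torch


def roll_1d_reference(values, shift):
    """Hypothetical helper: circularly shift a Python list by `shift`."""
    n = len(values)
    out = [None] * n
    for i, value in enumerate(values):
        # The element at index i lands at index (i + shift) modulo n.
        out[(i + shift) % n] = value
    return out


data = [1, 21, 31, 4, 5]
results = {}
for shift in (1, 2, 7, -2):
    expected = torch.roll(torch.tensor(data), shifts=shift).tolist()
    results[shift] = roll_1d_reference(data, shift) == expected

print(results)  # {1: True, 2: True, 7: True, -2: True}
```

The shift of 7 is included to show that a shift larger than the tensor length wraps around; for five elements it gives the same result as a shift of 2.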
Negative shifts

To move elements in the opposite direction, pass a negative value for shifts.
A positive shift moves elements toward the right side of the tensor, and a negative shift moves them toward the left side.
import torch

tensor = torch.tensor([1, 21, 31, 4, 5])
print(tensor)
# Output: tensor([ 1, 21, 31,  4,  5])

rolled = torch.roll(tensor, shifts=-2)
print(rolled)
# Output: tensor([31,  4,  5,  1, 21])
In the above code, 31 started in the third position (index 2). With shifts=-2, every element moves two positions to the left, so 31 lands in the first position, while 1 and 21 wrap around to the end of the tensor.
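Because the shift is circular, a left shift by 2 on a five-element tensor should match a right shift by 3, and rolling by the opposite amount should restore the original. A short sketch to confirm this, with variable names chosen for illustration:

```python
import torch

t = torch.tensor([1, 21, 31, 4, 5])
n = t.numel()  # 5 elements

left_by_2 = torch.roll(t, shifts=-2)
right_by_3 = torch.roll(t, shifts=n - 2)

# A left shift by k matches a right shift by n - k.
print(torch.equal(left_by_2, right_by_3))  # True

# Rolling by the opposite amount undoes the shift.
restored = torch.roll(left_by_2, shifts=2)
print(torch.equal(restored, t))  # True
```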
Rolling a 2D tensor across rows
To roll the rows of a 2D tensor, pass dims=0 as an argument.
import torch

tensor = torch.tensor([[11, 2, 31], [4, 51, 6]])
print(tensor)
# Output:
# tensor([[11,  2, 31],
#         [ 4, 51,  6]])

rolled_2d = torch.roll(tensor, shifts=1, dims=0)  # Shift down along dim 0 (rows)
print(rolled_2d)
# Output:
# tensor([[ 4, 51,  6],
#         [11,  2, 31]])
Here, dims=0 means that whole rows are shifted, in this case by one position. Because the matrix has only two rows, the first row becomes the last, and the last row becomes the first.
That’s how shifting works in a 2D tensor or matrix.
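With only two rows, shifting down and shifting up give the same result, so the direction is hard to see. The sketch below uses a 3x3 tensor, an assumption made for illustration and not taken from the examples above, to make the direction visible.

```python
import torch

# A 3x3 tensor with values 0..8, one row per line.
grid = torch.arange(9).reshape(3, 3)

down = torch.roll(grid, shifts=1, dims=0)   # each row moves down one place
up = torch.roll(grid, shifts=-1, dims=0)    # each row moves up one place

print(grid.tolist())  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(down.tolist())  # [[6, 7, 8], [0, 1, 2], [3, 4, 5]]
print(up.tolist())    # [[3, 4, 5], [6, 7, 8], [0, 1, 2]]
```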
Shifting elements of a 2D tensor across columns

import torch

tensor = torch.tensor([[11, 2, 31], [4, 51, 6]])
print(tensor)
# Output:
# tensor([[11,  2, 31],
#         [ 4, 51,  6]])

rolled_columns = torch.roll(tensor, shifts=2, dims=1)  # Shift right along dim 1 (columns)
print(rolled_columns)
# Output:
# tensor([[ 2, 31, 11],
#         [51,  6,  4]])
In the above code, we roll along dim 1, so elements move horizontally within each row. With shifts=2, every element moves two positions to the right, and elements pushed past the last column wrap around to the start of the same row.
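The column roll can also be pictured as cutting the tensor into two blocks of columns and swapping them. The sketch below rebuilds the same result with slicing and torch.cat. It assumes a shift k strictly between 0 and the number of columns; it is an illustration of the idea, not how torch.roll is implemented internally.

```python
import torch

t = torch.tensor([[11, 2, 31], [4, 51, 6]])
k = 2  # shift amount, assumed to satisfy 0 < k < number of columns

# The last k columns move to the front; the remaining columns follow.
manual = torch.cat((t[:, -k:], t[:, :-k]), dim=1)
rolled = torch.roll(t, shifts=k, dims=1)

print(manual.tolist())              # [[2, 31, 11], [51, 6, 4]]
print(torch.equal(manual, rolled))  # True
```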
Rolling across multiple dimensions
You can apply different shifts to multiple dimensions simultaneously. Pass a tuple of shifts together with a tuple of dimensions of the same length.

import torch

tensor = torch.tensor([[11, 2, 31], [4, 51, 6]])
print(tensor)
# Output:
# tensor([[11,  2, 31],
#         [ 4, 51,  6]])

rolled_multi_dimension = torch.roll(tensor, shifts=(1, 2), dims=(0, 1))  # Shift down 1, right 2
print(rolled_multi_dimension)
# Output:
# tensor([[51,  6,  4],
#         [ 2, 31, 11]])
In this code, rows are shifted down by 1 (dimension 0), and columns are shifted right by 2 (dimension 1), with wrapping occurring at boundaries.
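A roll over several dimensions should give the same result as rolling one dimension at a time. The sketch below checks this for the tensor used above; the variable names are chosen for illustration.

```python
import torch

t = torch.tensor([[11, 2, 31], [4, 51, 6]])

# One call with a tuple of shifts and a matching tuple of dims.
combined = torch.roll(t, shifts=(1, 2), dims=(0, 1))

# The same shifts applied one dimension at a time.
step = torch.roll(t, shifts=1, dims=0)     # rows down by 1
step = torch.roll(step, shifts=2, dims=1)  # columns right by 2

print(combined.tolist())            # [[51, 6, 4], [2, 31, 11]]
print(torch.equal(combined, step))  # True
```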