The torch.index_select() function returns a new tensor containing the entries (rows, columns, or other slices) of the input tensor selected along a specified dimension by a 1D tensor of indices. The result has the same number of dimensions as the input.

Syntax
torch.index_select(input, dim, index, *, out=None)
Parameters
Argument | Description |
input (Tensor) | The input tensor from which elements are selected. |
dim (int) | The dimension along which to index. For a 2D tensor, dim = 0 selects rows and dim = 1 selects columns. |
index (IntTensor or LongTensor) | A 1D tensor containing the indices of the elements to select. Indices must be non-negative and within the bounds of the specified dimension. |
out (Tensor, optional) | The output tensor in which to store the result. It is a keyword-only argument. |
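To make the parameters concrete, here is a minimal sketch of how dim, index, and out affect the result. The names x, idx, and buffer are illustrative only, and the buffer is assumed to be preallocated with the same shape and dtype as the result.

```python
import torch

x = torch.arange(12).reshape(3, 4)   # shape (3, 4)
idx = torch.tensor([2, 0])           # 1D integer index tensor

# The result keeps the number of dimensions; only `dim` changes size to len(idx).
rows = torch.index_select(x, 0, idx)
cols = torch.index_select(x, 1, idx)
print(rows.shape)  # torch.Size([2, 4])
print(cols.shape)  # torch.Size([3, 2])

# Optional `out`: write the result into a preallocated tensor
# (assumed here to match the result's shape and dtype).
buffer = torch.empty(2, 4, dtype=x.dtype)
torch.index_select(x, 0, idx, out=buffer)

# The result is a copy, so modifying it leaves `x` unchanged.
rows[0, 0] = -1
print(x[2, 0])     # tensor(8)
```

Because the returned tensor does not share storage with the input, it is safe to modify in place.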
Selecting specific rows from a 2D tensor
First, we define a 1D tensor of indices that lists the rows we want to select.
Because we are selecting rows, we pass dim = 0.
import torch

# Input tensor (3x3)
input_tensor = torch.tensor([[11, 21, 31],
                             [51, 61, 71],
                             [91, 101, 111]])

# Indices to select (rows 0 and 2)
row_indices = torch.tensor([0, 2])

# Select rows along dim=0
selected_matrix = torch.index_select(input_tensor, dim=0, index=row_indices)
print(selected_matrix)

# Output:
# tensor([[ 11,  21,  31],
#         [ 91, 101, 111]])
We selected row indices 0 and 2, so the output tensor contains only the first and third rows.
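The index tensor does not have to be sorted or unique. The sketch below reuses the same 3x3 tensor to show that rows come back in the order the indices are given, and that an index can appear more than once. The comparison with bracket indexing is included only as a sanity check.

```python
import torch

input_tensor = torch.tensor([[11, 21, 31],
                             [51, 61, 71],
                             [91, 101, 111]])

# Row 2 first, then row 0, then row 2 again.
row_indices = torch.tensor([2, 0, 2])
reordered = torch.index_select(input_tensor, dim=0, index=row_indices)
print(reordered)
# tensor([[ 91, 101, 111],
#         [ 11,  21,  31],
#         [ 91, 101, 111]])

# Indexing with the same tensor along dim 0 gives an equal result.
matches_bracket_indexing = torch.equal(reordered, input_tensor[row_indices])
print(matches_bracket_indexing)  # True
```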
Selecting specific columns from a 2D tensor

To select columns, we pass dim=1 together with a tensor of column indices.
import torch

matrix = torch.tensor([[11, 21, 31],
                       [51, 61, 71],
                       [91, 101, 111]])

# Indices to select (second column)
column_indices = torch.tensor([1])

# Selecting a column along dim=1
column_matrix = torch.index_select(matrix, dim=1, index=column_indices)
print(column_matrix)

# Output:
# tensor([[ 21],
#         [ 61],
#         [101]])
We passed a single index, 1, so the result contains only the second column of the matrix. Note that the output is still a 2D tensor, with shape (3, 1).
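If a flat 1D column is needed instead of a (3, 1) tensor, the size-1 dimension can be dropped afterwards. The following is a small sketch of that step; the names column_2d and column_1d are illustrative.

```python
import torch

matrix = torch.tensor([[11, 21, 31],
                       [51, 61, 71],
                       [91, 101, 111]])

# index_select keeps the column dimension, even for a single index.
column_2d = torch.index_select(matrix, dim=1, index=torch.tensor([1]))
print(column_2d.shape)  # torch.Size([3, 1])

# Dropping the size-1 dimension gives a flat 1D tensor.
column_1d = column_2d.squeeze(1)
print(column_1d)        # tensor([ 21,  61, 101])

# Plain integer indexing drops the dimension directly.
print(torch.equal(column_1d, matrix[:, 1]))  # True
```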
Selecting elements from a 1D Tensor

If the input tensor is 1D, there are no rows or columns, only elements. The only valid dimension is dim=0, and the indices pick individual elements.
import torch

tensor_1d = torch.tensor([101, 201, 301, 401, 501, 601])

# Indices to select
indices_of_elements = torch.tensor([1, 3, 5])

# Select elements along dim=0
selected_elements = torch.index_select(
    tensor_1d, dim=0, index=indices_of_elements)
print(selected_elements)

# Output: tensor([201, 401, 601])
Here, we select the elements at index 1 (201), index 3 (401), and index 5 (601). The output tensor contains these three values.
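The indices can also be generated rather than typed by hand, and they must stay within bounds. The sketch below picks every second element with torch.arange and then shows an out-of-range index being rejected. It assumes a CPU tensor, and it catches both IndexError and RuntimeError because the exact exception type is treated here as an assumption.

```python
import torch

tensor_1d = torch.tensor([101, 201, 301, 401, 501, 601])

# Build the indices programmatically: every second position (0, 2, 4).
even_positions = torch.arange(0, tensor_1d.numel(), 2)
every_second = torch.index_select(tensor_1d, dim=0, index=even_positions)
print(every_second)  # tensor([101, 301, 501])

# Index 6 is out of bounds for a tensor with 6 elements (valid: 0 to 5).
out_of_bounds_rejected = False
try:
    torch.index_select(tensor_1d, dim=0, index=torch.tensor([6]))
except (IndexError, RuntimeError) as err:
    out_of_bounds_rejected = True
    print("rejected:", type(err).__name__)
```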
3D tensor selection
Let’s select the first block of the 3D tensor.
import torch

tensor_3d = torch.arange(8).reshape(2, 2, 2)

indices = torch.tensor([0])

first_channel = torch.index_select(tensor_3d, dim=0, index=indices)
print(first_channel)

# Output:
# tensor([[[0, 1],
#          [2, 3]]])
In a 3D tensor, dim=0 corresponds to the outermost “layer” (the “blocks”), so we are extracting just the first block. The result is still 3D, with shape (1, 2, 2).
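The same call works along the other two dimensions. As a rough illustration, the loop below selects index 1 along each dimension of the same 2x2x2 tensor and prints the resulting shape and values. The results dictionary is a hypothetical helper used only to collect the outputs.

```python
import torch

tensor_3d = torch.arange(8).reshape(2, 2, 2)
last = torch.tensor([1])

results = {}
for dim in range(tensor_3d.dim()):
    picked = torch.index_select(tensor_3d, dim=dim, index=last)
    results[dim] = picked
    print(dim, tuple(picked.shape), picked.flatten().tolist())

# 0 (1, 2, 2) [4, 5, 6, 7]
# 1 (2, 1, 2) [2, 3, 6, 7]
# 2 (2, 2, 1) [1, 3, 5, 7]
```

In each case, only the selected dimension shrinks to the number of indices; the other dimensions keep their size.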