The torch.index_select() function extracts specific entries (rows, columns, or other slices) from the input tensor along a specified dimension, using a tensor of indices, and returns them as a new tensor.
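According to the PyTorch documentation, the returned tensor does not share storage with the input, so the result is a copy rather than a view. A minimal sketch of that behavior (the names x, idx, and picked are only for illustration):

```python
import torch

# Small 2-D tensor to select from
x = torch.tensor([[1, 2], [3, 4], [5, 6]])
idx = torch.tensor([0, 2])

# index_select copies the selected rows into a new tensor
picked = torch.index_select(x, 0, idx)

# Writing into the result leaves the original tensor unchanged
picked[0, 0] = 100
print(x[0, 0].item())  # 1
```

Because the data is copied, later changes to x are not reflected in picked either.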
Syntax
torch.index_select(input, dim, index, *, out=None)
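The same operation is also available as a Tensor method, input.index_select(dim, index). A short sketch comparing both forms (the tensor t and index idx are made up for the example); note that the output follows the order of the indices:

```python
import torch

t = torch.tensor([[11, 21, 31],
                  [51, 61, 71]])
idx = torch.tensor([2, 0])

# Functional form: torch.index_select(input, dim, index)
a = torch.index_select(t, 1, idx)

# Tensor method form: input.index_select(dim, index)
b = t.index_select(1, idx)

print(torch.equal(a, b))  # True
print(a.tolist())         # [[31, 11], [71, 51]]
```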
Parameters
| Argument | Description |
|---|---|
| input (Tensor) | The input tensor from which elements are selected. |
| dim (int) | The dimension along which to index. For a 2D tensor, dim = 0 selects rows and dim = 1 selects columns. |
| index (IntTensor or LongTensor) | A 1D tensor containing the indices to select. Indices must be non-negative and smaller than the size of input along dim. |
| out (Tensor, optional) | A keyword-only argument: the tensor in which to store the result. |
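A small sketch of the out argument and of the bounds rule for index (src, idx, buf, and raised are illustrative names). It assumes a reasonably recent PyTorch version; the exact exception type raised for an out-of-range index may differ between versions, so both candidates are caught:

```python
import torch

src = torch.arange(12).reshape(3, 4)  # rows: [0..3], [4..7], [8..11]
# int32 indices are accepted as well as the default int64
idx = torch.tensor([0, 2], dtype=torch.int32)

# Preallocate the result: 2 selected rows x 4 columns, same dtype as src
buf = torch.empty(2, 4, dtype=src.dtype)
torch.index_select(src, 0, idx, out=buf)  # out must be passed by keyword
print(buf.tolist())  # [[0, 1, 2, 3], [8, 9, 10, 11]]

# An index outside [0, src.size(0) - 1] is rejected
raised = False
try:
    torch.index_select(src, 0, torch.tensor([3]))
except (IndexError, RuntimeError):
    # The exception type may vary by PyTorch version
    raised = True
print(raised)  # True
```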
Selecting specific rows from a 2D tensor
First, we define a tensor of indices that specifies which rows we want to select.
Because we are selecting rows, we pass dim = 0.
import torch

# Input tensor (3x3)
input_tensor = torch.tensor([[11, 21, 31],
                             [51, 61, 71],
                             [91, 101, 111]])

# Indices to select (rows 0 and 2)
row_indices = torch.tensor([0, 2])

# Select rows along dim=0
selected_matrix = torch.index_select(input_tensor,
                                     dim=0,
                                     index=row_indices)
print(selected_matrix)
# Output:
# tensor([[ 11,  21,  31],
#         [ 91, 101, 111]])
We select the rows at indices 0 and 2, so the output tensor contains only the first and third rows.
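The index tensor does not have to be sorted or unique: rows are returned in the order the indices list them, and a repeated index repeats the row. A sketch reusing the 3x3 tensor from above, with a hypothetical index [2, 0, 2]; for dim=0 the result also matches advanced indexing with the same index tensor:

```python
import torch

input_tensor = torch.tensor([[11, 21, 31],
                             [51, 61, 71],
                             [91, 101, 111]])

# Indices may repeat and may appear in any order
row_indices = torch.tensor([2, 0, 2])
picked = torch.index_select(input_tensor, dim=0, index=row_indices)
print(picked.tolist())  # [[91, 101, 111], [11, 21, 31], [91, 101, 111]]

# Same result as advanced indexing along the first dimension
print(torch.equal(picked, input_tensor[row_indices]))  # True
```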
Selecting specific columns from a 2D tensor
To select columns, we need to pass dim=1 along with the column indices.
import torch
matrix = torch.tensor([[11, 21, 31],
                       [51, 61, 71],
                       [91, 101, 111]])

# Indices to select (second column)
column_indices = torch.tensor([1])

# Selecting a column along dim=1
column_matrix = torch.index_select(matrix,
                                   dim=1,
                                   index=column_indices)
print(column_matrix)
# Output:
# tensor([[ 21],
#         [ 61],
#         [101]])
We passed a single index, 1, so only the second column of the matrix is returned.
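Even with a single index, index_select keeps the selected dimension, so the column comes back with shape (3, 1) rather than (3,). A sketch contrasting it with plain integer indexing (kept and dropped are illustrative names):

```python
import torch

matrix = torch.tensor([[11, 21, 31],
                       [51, 61, 71],
                       [91, 101, 111]])

# index_select keeps dim 1, even though only one column is selected
kept = torch.index_select(matrix, dim=1, index=torch.tensor([1]))
print(kept.shape)     # torch.Size([3, 1])

# Integer indexing drops that dimension instead
dropped = matrix[:, 1]
print(dropped.shape)  # torch.Size([3])

# Squeezing dim 1 of the index_select result gives the same values
print(torch.equal(kept.squeeze(1), dropped))  # True
```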
Selecting elements from a 1D Tensor
A 1D tensor has no rows or columns, only elements, so we select elements along its only dimension, dim=0.
import torch
tensor_1d = torch.tensor([101, 201, 301, 401, 501, 601])
# Indices to select
indices_of_elements = torch.tensor([1, 3, 5])
# Select elements along dim=0
selected_elements = torch.index_select(tensor_1d,
                                       dim=0,
                                       index=indices_of_elements)
print(selected_elements)
# Output: tensor([201, 401, 601])
Here, we select the elements at indices 1, 3, and 5, which hold 201, 401, and 601, so the output tensor contains exactly these values.
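Because the indices can come from any computation, index_select can also reorder a 1D tensor. A small sketch, assuming made-up integer scores, that builds the index with torch.argsort and also shows that a negative dim (-1) refers to the same single dimension:

```python
import torch

scores = torch.tensor([30, 90, 10, 50])

# argsort returns the positions that would sort scores in ascending order
order = torch.argsort(scores)  # positions [2, 0, 3, 1]

# Gather the elements in that order; dim=-1 is the last (and only) dimension here
sorted_scores = torch.index_select(scores, -1, order)
print(sorted_scores.tolist())  # [10, 30, 50, 90]

# Matches the values returned by torch.sort
print(torch.equal(sorted_scores, torch.sort(scores).values))  # True
```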
3D tensor selection
Let’s select the first block of the 3D tensor.
import torch

# 3D tensor of shape (2, 2, 2): two 2x2 blocks
tensor_3d = torch.arange(8).reshape(2, 2, 2)

# Select the first block along dim=0
indices = torch.tensor([0])
first_block = torch.index_select(tensor_3d, dim=0, index=indices)
print(first_block)
# Output:
# tensor([[[0, 1],
#          [2, 3]]])
In a 3D tensor, dim=0 indexes the outermost “layer” (the “blocks”), so we extract just the first block. The result keeps all three dimensions and has shape (1, 2, 2).
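The same call works along the inner dimensions of a 3D tensor. A sketch using the same torch.arange(8).reshape(2, 2, 2) tensor, where dim=1 picks a row inside every block and dim=2 (written here as dim=-1) picks a column inside every block; second_rows and first_cols are illustrative names:

```python
import torch

tensor_3d = torch.arange(8).reshape(2, 2, 2)
# Layout: tensor_3d[block, row, column]

# dim=1 picks the second row of every block
second_rows = torch.index_select(tensor_3d, dim=1, index=torch.tensor([1]))
print(second_rows.shape)     # torch.Size([2, 1, 2])
print(second_rows.tolist())  # [[[2, 3]], [[6, 7]]]

# dim=-1 (the last dimension, here dim=2) picks the first column of every block
first_cols = torch.index_select(tensor_3d, dim=-1, index=torch.tensor([0]))
print(first_cols.shape)      # torch.Size([2, 2, 1])
print(first_cols.tolist())   # [[[0], [2]], [[4], [6]]]
```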
