Sprint Chase Technologies
torch.index_select() Method in PyTorch

  • Written by krunallathiya21
  • June 13, 2025

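The torch.index_select() method extracts specific slices (rows, columns, or other sub-tensors) from the input tensor along a specified dimension, based on a tensor of indices. The result has the same number of dimensions as the input. Its size along the chosen dimension equals the number of indices, and all other sizes are unchanged.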
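The short sketch below checks the shape rule on a small 2x4 tensor. It also checks one more property stated in the PyTorch documentation: the returned tensor does not share storage with the input, so writing to it leaves the input unchanged. The variable names are illustrative only.

```python
import torch

# A 2x4 input tensor holding the values 0..7
x = torch.arange(8).reshape(2, 4)

# Select columns 3 and 0 (in that order) along dim=1
idx = torch.tensor([3, 0])
y = torch.index_select(x, dim=1, index=idx)

# Same number of dimensions as x; size along dim=1 equals len(idx)
print(y.shape)  # torch.Size([2, 2])
print(y)
# tensor([[3, 0],
#         [7, 4]])

# The result is a new tensor: modifying it does not change x
y[0, 0] = 100
print(x[0, 3])  # tensor(3)
```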

Syntax

torch.index_select(input, dim, index, *, out=None)
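The same operation is also available as a Tensor method, t.index_select(dim, index). Because out is keyword-only, a preallocated result tensor must be passed as out=... rather than positionally. A minimal sketch comparing the three call styles:

```python
import torch

t = torch.tensor([[1, 2], [3, 4], [5, 6]])
idx = torch.tensor([2, 0])

# Function form
a = torch.index_select(t, 0, idx)

# Equivalent Tensor method form: t.index_select(dim, index)
b = t.index_select(0, idx)

# `out` is keyword-only: pass a preallocated tensor to receive the result
out = torch.empty(2, 2, dtype=t.dtype)
torch.index_select(t, 0, idx, out=out)

print(a)
# tensor([[5, 6],
#         [1, 2]])
print(torch.equal(a, b), torch.equal(a, out))  # True True
```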

Parameters

Argument | Description
input (Tensor) | The input tensor from which elements are selected.
dim (int) | The dimension along which to index. For a 2D tensor, dim = 0 selects rows and dim = 1 selects columns.
index (LongTensor) | A 1D tensor containing the indices to select. Indices must be non-negative and within the bounds of the specified dimension.
out (Tensor, optional) | Keyword-only. An output tensor in which to store the result.
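To illustrate the bounds requirement on index, the sketch below passes an out-of-range index and catches the resulting error. The except clause catches both IndexError and RuntimeError because the exact exception type is an assumption here. The safe_index_select helper is hypothetical and not part of PyTorch; it simply validates indices up front and reports a clearer message.

```python
import torch

t = torch.tensor([10, 20, 30])

# Index 3 is outside the valid range [0, 2] for a dimension of size 3,
# so PyTorch raises an error instead of returning a value.
try:
    torch.index_select(t, 0, torch.tensor([0, 3]))
    raised = False
except (IndexError, RuntimeError) as err:
    raised = True
    print("error:", err)

print(raised)  # True


# Hypothetical helper (not a PyTorch API): check bounds before selecting
def safe_index_select(x, dim, index):
    size = x.size(dim)
    if index.numel() > 0 and (index.min() < 0 or index.max() >= size):
        raise ValueError(f"indices must be in [0, {size - 1}] for dim={dim}")
    return torch.index_select(x, dim, index)


print(safe_index_select(t, 0, torch.tensor([2, 0])))  # tensor([30, 10])
```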

Selecting specific rows from a 2D tensor

First, we define a tensor of indices specifying which rows to select.

Because we are selecting rows, we pass dim=0.

import torch

# Input tensor (3x3)
input_tensor = torch.tensor([[11, 21, 31],
                             [51, 61, 71],
                             [91, 101, 111]])

# Indices to select (rows 0 and 2)
row_indices = torch.tensor([0, 2])

# Select rows along dim=0
selected_matrix = torch.index_select(input_tensor,
                                     dim=0,
                                     index=row_indices)

print(selected_matrix)

# Output:
# tensor([[ 11,  21,  31],
#         [ 91, 101, 111]])

We select the rows at indices 0 and 2, so the output tensor contains only the first and third rows.
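The index tensor does not have to be sorted or unique: indices can appear in any order and can repeat, and the output follows that order. A small sketch reusing the same 3x3 matrix:

```python
import torch

input_tensor = torch.tensor([[11, 21, 31],
                             [51, 61, 71],
                             [91, 101, 111]])

# Indices may appear in any order and may repeat
row_indices = torch.tensor([2, 0, 2])

reordered = torch.index_select(input_tensor, dim=0, index=row_indices)
print(reordered)
# tensor([[ 91, 101, 111],
#         [ 11,  21,  31],
#         [ 91, 101, 111]])
```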

Selecting specific columns from a 2D tensor

To select columns, we need to pass dim=1 along with the column indices.

import torch

matrix = torch.tensor([[11, 21, 31],
                       [51, 61, 71],
                       [91, 101, 111]])

# Indices to select (second column)
column_indices = torch.tensor([1])

# Selecting a column along dim=1
column_matrix = torch.index_select(matrix,
                                   dim=1,
                                   index=column_indices)

print(column_matrix)

# Output:
# tensor([[ 21],
#         [ 61],
#         [101]])

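We passed a single index, 1, so the result contains only the second column of the matrix. The output is still a 2D tensor of shape (3, 1), because index_select() keeps the indexed dimension even when only one index is given.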
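For comparison, the sketch below sets the index_select() result beside standard tensor indexing. Indexing with a list ([1]) keeps the column dimension like index_select() does, while indexing with a plain integer (1) drops it.

```python
import torch

matrix = torch.tensor([[11, 21, 31],
                       [51, 61, 71],
                       [91, 101, 111]])

col_select = torch.index_select(matrix, dim=1, index=torch.tensor([1]))
col_list = matrix[:, [1]]   # list index keeps the dimension
col_int = matrix[:, 1]      # integer index drops the dimension

print(col_select.shape)  # torch.Size([3, 1])
print(col_list.shape)    # torch.Size([3, 1])
print(col_int.shape)     # torch.Size([3])
print(torch.equal(col_select, col_list))  # True
```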

Selecting elements from a 1D Tensor

A 1D tensor has no rows or columns, only elements, so we select along dim=0 and each index picks a single element.

import torch

tensor_1d = torch.tensor([101, 201, 301, 401, 501, 601])

# Indices to select
indices_of_elements = torch.tensor([1, 3, 5])

# Select elements along dim=0
selected_elements = torch.index_select(
    tensor_1d, dim=0,
    index=indices_of_elements)

print(selected_elements)
# Output: tensor([201, 401, 601])

Here, index 1 selects 201, index 3 selects 401, and index 5 selects 601, so the output tensor contains these three values.
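In practice, the index tensor is often computed instead of written by hand. As one possible approach (an illustration, not the only way), the sketch below uses torch.nonzero to find the positions of values greater than 350. It then selects those elements and confirms that, for a 1D tensor, plain advanced indexing gives the same result.

```python
import torch

tensor_1d = torch.tensor([101, 201, 301, 401, 501, 601])

# Build the index tensor from a condition: positions of values > 350.
# torch.nonzero returns shape (N, 1), so squeeze it to a 1D index.
indices = torch.nonzero(tensor_1d > 350).squeeze(1)
print(indices)  # tensor([3, 4, 5])

selected = torch.index_select(tensor_1d, dim=0, index=indices)
print(selected)  # tensor([401, 501, 601])

# For a 1D tensor, advanced indexing gives the same result
print(torch.equal(selected, tensor_1d[indices]))  # True
```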

3D tensor selection

Let’s select the first block of the 3D tensor.

import torch

# 2x2x2 tensor holding the values 0..7
tensor_3d = torch.arange(8).reshape(2, 2, 2)

# Index of the block to select (the first one)
indices = torch.tensor([0])

first_block = torch.index_select(tensor_3d, dim=0, index=indices)

print(first_block)
# Output:
# tensor([[[0, 1],
#          [2, 3]]])

In a 3D tensor, dim=0 indexes the outermost level (the “blocks”), so this call extracts only the first block. The result keeps all three dimensions and has shape (1, 2, 2).
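The same 2x2x2 tensor can be indexed along its inner dimensions as well. In the sketch below, dim=1 picks the second row of every block, and dim=2 picks the second column of every row. In each case only the indexed dimension shrinks to length 1.

```python
import torch

tensor_3d = torch.arange(8).reshape(2, 2, 2)
idx = torch.tensor([1])

# dim=1: second row of every block
rows = torch.index_select(tensor_3d, dim=1, index=idx)
print(rows.shape)     # torch.Size([2, 1, 2])
print(rows.tolist())  # [[[2, 3]], [[6, 7]]]

# dim=2: second column of every row in every block
cols = torch.index_select(tensor_3d, dim=2, index=idx)
print(cols.shape)     # torch.Size([2, 2, 1])
print(cols.tolist())  # [[[1], [3]], [[5], [7]]]
```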
