Sprint Chase Technologies
  • Home
  • About
    • Why Choose Us
    • Contact Us
    • Team Members
    • Testimonials
  • Services
    • Web Development
    • Web Application Development
    • Mobile Application Development
    • Web Design
    • UI/UX Design
    • Social Media Marketing
    • Projects
  • Blog
    • PyTorch
    • Python
    • JavaScript
  • IT Institute

Need Help? Talk to an Expert

+91 8000107255

torch.argmax(): Find the Index of Max Value in a Tensor

  • Written by krunallathiya21
  • May 12, 2025
PyTorch

The torch.argmax() method returns the indices of the maximum values of the input tensor, either along a specified dimension or over all elements. The result is a tensor of dtype torch.int64 (a LongTensor), the integer type PyTorch uses for indices.
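A quick way to confirm the return type is to inspect the result for both call forms. The `scores` tensor below is a made-up example; note that the input is float, yet the indices come back as torch.int64.

```python
import torch

# Made-up float scores, 2 rows x 3 columns
scores = torch.tensor([[0.2, 0.9, 0.4],
                       [0.7, 0.1, 0.3]])

# Without dim: a single index into the flattened tensor (a 0-dim tensor)
flat_idx = torch.argmax(scores)
# With dim=1: one index per row
row_idx = torch.argmax(scores, dim=1)

print(flat_idx, flat_idx.dtype, flat_idx.shape)  # tensor(1) torch.int64 torch.Size([])
print(row_idx, row_idx.dtype, row_idx.shape)     # tensor([1, 0]) torch.int64 torch.Size([2])
```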


If the maximum value occurs more than once, argmax returns the index of its first occurrence.

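This operation is not differentiable: its output is an integer index tensor that is not connected to the autograd graph, so no gradient flows back through it. Avoid relying on it inside a computation that needs gradients.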
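The sketch below, with a made-up input, shows what this means in practice: the index returned by argmax does not require gradients, but if you use that index to pick the value out of the original tensor, the picked value is still part of the graph, and only the selected element receives a gradient.

```python
import torch

x = torch.tensor([0.5, 2.0, -1.0], requires_grad=True)

# The integer index carries no autograd history
idx = torch.argmax(x)
print(idx.requires_grad, idx.grad_fn)  # False None

# Indexing x with that index keeps the graph for the selected value
picked = x[idx]
picked.backward()
print(x.grad)  # tensor([0., 1., 0.])
```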

The main difference between torch.max() and torch.argmax() is what they return. Called without dim, torch.max() returns only the maximum value; called with dim, it returns both the maximum values and their indices. torch.argmax() only ever returns the indices.

When dim is given, torch.argmax(input, dim, keepdim) is equivalent to torch.max(input, dim=dim, keepdim=keepdim)[1], that is, the indices part of torch.max()'s result.
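The following sketch compares the calls on a small made-up tensor. It deliberately has no tied values in any row, so the comparison does not depend on how ties are broken.

```python
import torch

t = torch.tensor([[4, 8, 1],
                  [7, 2, 6]])

# torch.max without dim returns only the largest value
print(torch.max(t))  # tensor(8)

# torch.max with dim returns a (values, indices) named tuple
values, indices = torch.max(t, dim=1)
print(values, indices)  # tensor([8, 7]) tensor([1, 0])

# torch.argmax with the same dim returns just the indices
print(torch.argmax(t, dim=1))  # tensor([1, 0])
print(torch.equal(torch.argmax(t, dim=1), indices))  # True
```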

Function signature

torch.argmax(input)

# OR

torch.argmax(input, dim, keepdim=False)

Parameters

Argument                  Description
input (Tensor)            The input tensor whose maximum values' indices we need to find.
dim (int, optional)       The dimension to reduce. If None, the index is computed over the flattened input.
keepdim (bool, optional)  If True, the output retains the reduced dimension with size 1. Default: False.

Basic Usage on 1D Tensor

PyTorch tensors are zero-indexed: the first element has index 0, the second has index 1, and so on.

import torch

tensor = torch.tensor([11, 21, 19, 10])

max_element_id = torch.argmax(tensor)

print(max_element_id)  

# Output: tensor(1)

Element 21 is the maximum, and it is the second element of the tensor, so its index is 1. That is why the output is tensor(1).
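The result is a 0-dim tensor rather than a plain number. A small sketch of the common follow-up step: .item() turns it into a Python int, which can then be used to read the maximum value back from the original tensor.

```python
import torch

tensor = torch.tensor([11, 21, 19, 10])

max_element_id = torch.argmax(tensor)

# .item() converts the 0-dim index tensor into a plain Python int
position = max_element_id.item()
print(position, type(position))  # 1 <class 'int'>

# Use the index to look up the maximum value itself
print(tensor[position])  # tensor(21)
```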

Ties in Maximum Values


What if our tensor contains multiple maximum values? In that case, it will return the index of the first occurrence.

import torch

tensor = torch.tensor([21, 19, 21, 10])

max_element_id = torch.argmax(tensor)

print(max_element_id)

# Output: tensor(0)

Element 21 is the maximum value in the above tensor, and it occurs twice: at index 0 and again at index 2. The output is tensor(0), the index of the first occurrence.
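argmax itself only reports the first of the tied positions. If you need every position that holds the maximum, one option (separate from argmax) is to compare the tensor against its maximum and collect the matching indices with nonzero(), as in this sketch:

```python
import torch

tensor = torch.tensor([21, 19, 21, 10])

# Boolean mask marking every position equal to the maximum value
is_max = tensor == tensor.max()

# nonzero(as_tuple=True) returns one index tensor per dimension; a 1D input has one
all_max_ids = is_max.nonzero(as_tuple=True)[0]
print(all_max_ids)  # tensor([0, 2])
```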

Flattening a 2D tensor

If you only need the index of the global maximum, regardless of the tensor's shape, call argmax without dim. The input is then treated as if it were flattened into a 1D tensor.

Let’s find the index of the maximum value across the 2D tensor.

import torch

tensor_2d = torch.tensor([[1, 5, 2],
                          [3, 2, 9]])

max_index_across_2d_tensor = torch.argmax(tensor_2d)

print(max_index_across_2d_tensor)

# Output: tensor(5)

In the above code, element 9 is the maximum value. In the flattened order [1, 5, 2, 3, 2, 9] it is the last element, so its index is 5.

Because no dim was passed, argmax worked on a flattened view of the input; the tensor itself is not modified.

The returned index corresponds to this 1D view, not the original 2D coordinates.
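To get back to 2D coordinates, you can undo the row-major flattening yourself: the flat index equals row * num_cols + col, so divmod recovers both parts. Newer PyTorch releases also include torch.unravel_index for this; the plain divmod version below avoids depending on a particular version.

```python
import torch

tensor_2d = torch.tensor([[1, 5, 2],
                          [3, 2, 9]])

flat_index = torch.argmax(tensor_2d).item()  # 5

# Row-major order: flat_index = row * num_cols + col
num_cols = tensor_2d.shape[1]
row, col = divmod(flat_index, num_cols)
print(row, col)  # 1 2

print(tensor_2d[row, col])  # tensor(9)
```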

2D Tensor with dim=0 (column-wise)

If the input tensor is 2D, use the dim argument to find the indices of the maximum values column-wise or row-wise. To find them column-wise, pass dim=0: argmax then compares the rows within each column and returns one row index per column.

import torch

tensor_2d = torch.tensor([[1, 5, 2],
                          [3, 2, 9]])

column_wise_id = torch.argmax(tensor_2d, dim=0)

print(column_wise_id)

# Output: tensor([1, 0, 1])

The first index is 1: in the first column, 3 is larger than 1, and 3 sits in row 1.

The second index is 0: in the second column, 5 is larger than 2, and 5 sits in row 0.

The third index is 1: in the third column, 9 is larger than 2, and 9 sits in row 1.
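Since each returned number is a row index, pairing it with its column number reads the column maxima back out of the tensor. A short sketch using advanced indexing:

```python
import torch

tensor_2d = torch.tensor([[1, 5, 2],
                          [3, 2, 9]])

column_wise_id = torch.argmax(tensor_2d, dim=0)  # tensor([1, 0, 1])

# Pair each winning row index with its column number 0, 1, 2
cols = torch.arange(tensor_2d.shape[1])
print(tensor_2d[column_wise_id, cols])  # tensor([3, 5, 9])
```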

2D Tensor with dim=1 (row-wise)

Pass dim=1 to get the indices of the maximum values row-wise: argmax compares the columns within each row and returns one column index per row.

import torch

tensor_2d = torch.tensor([[1, 5, 2],
                          [3, 2, 9]])

row_wise_id = torch.argmax(tensor_2d, dim=1)

print(row_wise_id)

# Output: tensor([1, 2])

In the first row, element 5 is the maximum, and its index is 1.

In the second row, element 9 is the maximum, and its index is 2.
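Row-wise argmax is a common way to turn classifier scores into predicted class labels. The sketch below assumes hypothetical logits of shape [batch, num_classes] and hypothetical ground-truth labels; neither comes from a real model.

```python
import torch

# Hypothetical logits: 3 samples x 3 classes
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.3, 0.2],
                       [0.0, 0.1, 3.0]])
labels = torch.tensor([1, 0, 1])  # hypothetical ground-truth classes

# dim=1 picks the highest-scoring class for each sample (row)
preds = torch.argmax(logits, dim=1)
print(preds)  # tensor([1, 0, 2])

# Fraction of samples where the prediction matches the label
accuracy = (preds == labels).float().mean()
print(accuracy)  # tensor(0.6667)
```

Applying softmax to the logits first would not change the predictions, because softmax preserves the ordering of values within each row.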

With keepdim=True

If you pass keepdim=True, the reduced dimension is kept with size 1, so the output has the same number of dimensions as the input. This is handy for broadcasting or reshaping.

import torch

tensor = torch.tensor([[1, 2, 3],
                       [11, 21, 19]])

reduced_index = torch.argmax(tensor, dim=1, keepdim=True)

print(reduced_index)

# Output:
# tensor([[2],
#         [1]])

The output has shape [2, 1]: dimension 1 was reduced but kept with size 1 instead of being removed. Without keepdim, the shape would be [2].
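One place where keepdim=True pays off is torch.gather, which expects an index tensor with the same number of dimensions as the input. A minimal sketch that uses the kept-dimension indices to fetch the row maxima:

```python
import torch

tensor = torch.tensor([[1, 2, 3],
                       [11, 21, 19]])

reduced_index = torch.argmax(tensor, dim=1, keepdim=True)  # shape [2, 1]

# gather along dim=1: out[i][0] = tensor[i][reduced_index[i][0]]
max_values = torch.gather(tensor, 1, reduced_index)
print(max_values.tolist())  # [[3], [21]]
```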

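As a reminder, when dim is omitted entirely, the result is a single index into the flattened tensor:

import torch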

tensor_2dim = torch.tensor([[1, 4], [3, 2]])

idx = torch.argmax(tensor_2dim)

print(idx)

# Output: tensor(1)

Since dim is not specified, argmax works on the flattened tensor [1, 4, 3, 2], where the maximum value 4 sits at index 1.

Higher-Dimensional Tensors

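We can use torch.randn() to create a 4D tensor of shape [1, 2, 3, 4] (batch, channel, height, width) and then apply argmax() with dim=3 to find, for each [batch, channel, height] position, the index of the maximum value along the width (4 elements). Because torch.randn() draws random values, your numbers will differ from the ones shown below.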

import torch

tensor_4d = torch.randn(1, 2, 3, 4)

print(tensor_4d)
# Output:
# tensor([[[[-2.0140, -0.9451, -0.9034,  0.3920],
#           [ 0.2874, -0.8883,  0.6729,  2.2473],
#           [ 0.3683, -0.7728, -1.6260,  0.9894]],

#          [[-1.7922,  1.0030,  0.9449, -0.4141],
#           [-0.2690, -0.5286,  1.1329,  1.4287],
#           [-0.7494, -0.4388, -0.3823, -0.9693]]]])

print(tensor_4d.shape)
# Output: torch.Size([1, 2, 3, 4])

argmax_dim3 = torch.argmax(tensor_4d, dim=3, keepdim=False)

print(argmax_dim3)
# Output:
# tensor([[[3, 3, 3],
#          [1, 3, 2]]])

print(argmax_dim3.shape)
# Output: torch.Size([1, 2, 3])
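For higher-dimensional tensors it is often convenient to count dimensions from the end: dim=-1 always means the last dimension, which is dim=3 here. The sketch below fixes a random seed so that runs are reproducible, then checks the equivalence and shows how keepdim and the choice of axis change the output shape.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is reproducible
tensor_4d = torch.randn(1, 2, 3, 4)

# dim=-1 refers to the last dimension, the same as dim=3 for a 4D tensor
last = torch.argmax(tensor_4d, dim=-1)
print(torch.equal(last, torch.argmax(tensor_4d, dim=3)))  # True

# keepdim=True keeps the reduced width dimension with size 1
print(torch.argmax(tensor_4d, dim=-1, keepdim=True).shape)  # torch.Size([1, 2, 3, 1])

# Reducing the channel axis (dim=1) removes that axis instead
print(torch.argmax(tensor_4d, dim=1).shape)  # torch.Size([1, 3, 4])
```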
That’s all!

Copyright by @SprintChase  All Rights Reserved
