Sprint Chase Technologies

Need Help? Talk to an Expert

+91 8000107255

torch.max(): Maximum Value of a Tensor in PyTorch

  • Written by krunallathiya21
  • May 12, 2025
PyTorch

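The torch.max() method returns the maximum value(s) of a tensor, either across all elements or along a specified dimension. When you specify a dimension, it also returns the indices of the maximum values. Given a second tensor, it computes the element-wise maximum of the two.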

[Figure: torch.max() method on 1D Tensor]

If you call torch.max() without a dim argument to find the global maximum, it returns a 0-dimensional (scalar) tensor. To extract the plain Python number from a scalar tensor, use the tensor.item() method.

If you pass a dim argument, it returns a named tuple (values, indices): the maximum values along that dimension and the index of each maximum.
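To see the two return types side by side, here is a small sketch. It borrows the 2x3 tensor used in the dimension examples later in this article; the names global_max and result are illustrative only.

```python
import torch

# The 2x3 tensor used in the dimension examples
tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

# Without dim: a single 0-dimensional (scalar) tensor
global_max = torch.max(tensor)
print(global_max.ndim)  # 0

# With dim: a named tuple whose fields are .values and .indices
result = torch.max(tensor, dim=1)
print(result.values)   # tensor([31., 44.])
print(result.indices)  # tensor([2, 0])

# The named tuple can also be unpacked like an ordinary tuple
values, indices = result
```

Accessing fields by name (result.values, result.indices) can make longer code easier to read than positional unpacking.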

Function syntax

torch.max(input)

# OR

torch.max(input, dim, keepdim=False, *, out=None)

# OR

torch.max(input, other, *, out=None)

Parameters

Argument Description
input (Tensor) The input tensor in which to find the maximum value.
dim (int) The dimension to reduce. If you omit it, all elements are compared and a single maximum value is returned.
keepdim (bool, optional) Whether the output keeps the reduced dimension. Defaults to False; if True, the reduced dimension is retained with size 1.
out (tuple or Tensor, optional) Pre-allocated output: a tuple of two tensors (max, max_indices) when dim is given, or a single tensor for the element-wise form.
other (Tensor) A second tensor for element-wise comparison. Pass it instead of dim, not together with it.

Finding the global maximum

Let’s define a 1D tensor and find the maximum value from it.

import torch

tensor = torch.tensor([11, 21, 19, 46, 10])

max_value = torch.max(tensor)

print(max_value)
# Output: tensor(46)

print(max_value.type())
# Output: torch.LongTensor

print(max_value.item())
# Output: 46

In the above code, the maximum value of the tensor is 46. torch.max() returns it as a scalar tensor, tensor(46). If you want a plain Python number instead of a tensor, call tensor.item(), which returns 46.
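A scalar tensor and the number returned by item() are different kinds of object. This sketch, which adds a float tensor that is not part of the example above, checks the shape of the result and the Python type that item() hands back.

```python
import torch

int_tensor = torch.tensor([11, 21, 19, 46, 10])  # int64 by default
float_tensor = torch.tensor([1.5, 0.25, 3.75])    # float32 by default

int_max = torch.max(int_tensor)
float_max = torch.max(float_tensor)

# A scalar tensor has an empty shape
print(int_max.shape)  # torch.Size([])

# .item() returns a built-in Python number that matches the dtype
print(type(int_max.item()))    # <class 'int'>
print(type(float_max.item()))  # <class 'float'>
print(float_max.item())        # 3.75
```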

Global Maximum within 2D Tensor

[Figure: torch.max() method on 2D Tensor]

Let’s find the single global maximum value of a 2D tensor.

import torch

tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

max_value = torch.max(tensor)

print(max_value)

# Output: tensor(44.)

In this case, torch.max() compares every element, as if the tensor were flattened to [11.0, 12.0, 31.0, 44.0, 15.0, 1.0], and returns the largest one, 44.0, as a scalar tensor.
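torch.max(tensor) without dim gives only the value, not where it sits. One way to recover the position, sketched here with torch.argmax() and Python's divmod(), is to take the flat index and convert it back to a row and a column. This assumes a 2D tensor like the one above.

```python
import torch

tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

# torch.argmax() without dim returns the index into the flattened tensor
flat_index = torch.argmax(tensor).item()
print(flat_index)  # 3

# Convert the flat (row-major) index back to (row, column)
num_cols = tensor.shape[1]
row, col = divmod(flat_index, num_cols)
print(row, col)  # 1 0

# The element at that position is the global maximum
print(tensor[row, col] == torch.max(tensor))  # tensor(True)
```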

Multiple Maxima

[Figure: Ties between multiple maximum values]

What if a tensor contains several elements equal to the maximum? The returned value is the same whichever one is picked. When you reduce along a dimension, the returned index points to the first occurrence of the maximum. The example below uses tensor.max(), the method form of torch.max(tensor).

import torch

tensor = torch.tensor([[11.0, 21.0],
                       [21.0, 13.0]])

max_value = tensor.max()

print(max_value)

# Output: tensor(21.)
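The example above only returns a value, so the tie rule for indices is not visible there. The sketch below uses a made-up tensor with a tie in each row and reduces along dim=1. Recent PyTorch documentation for torch.max states that the index of the first maximal value is returned in this case, so the expected indices are 1 and 0.

```python
import torch

# Made-up data: each row has two equal maximum values
tensor = torch.tensor([[5.0, 9.0, 9.0],
                       [7.0, 7.0, 2.0]])

values, indices = torch.max(tensor, dim=1)
print(values)   # tensor([9., 7.])
print(indices)  # tensor([1, 0]), the first occurrence in each row
```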

Maximum along a dimension (with Indices)

[Figure: Maximum along both column and row dimensions (with Indices)]

Specifying dim reduces the tensor along that dimension, returning both the maximum values and their indices.

For example, if you pass dim=0, it compares the elements within each column and returns one maximum per column. If you have three columns, the result contains three values.

If you pass dim=1, it compares the elements within each row and returns one maximum per row. If you have two rows, the result contains two values.

import torch

tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

# Max along dim=0 (columns)
values, indices = torch.max(tensor, dim=0)
print(values)
# Output: tensor([44., 15., 31.])
print(indices)
# Output: tensor([1, 1, 0])

# Max along dim=1 (rows)
values, indices = torch.max(tensor, dim=1)
print(values)
# Output: tensor([31., 44.])
print(indices)
# Output: tensor([2, 0])

Along dim=0 (columns), the tensor is reduced by comparing elements in each column:

  • Column 1: [11.0, 44.0] → max = 44.0 (index 1)
  • Column 2: [12.0, 15.0] → max = 15.0 (index 1)
  • Column 3: [31.0, 1.0] → max = 31.0 (index 0)

Along dim=1 (rows), the tensor is reduced by comparing elements in each row:

  • Row 1: [11.0, 12.0, 31.0] → max = 31.0 (index 2)
  • Row 2: [44.0, 15.0, 1.0] → max = 44.0 (index 0)
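To confirm that the indices really point at the maximum values, you can feed them back into tensor.gather(). This sketch also uses dim=-1 as another way to name the last dimension; neither call appears in the original examples.

```python
import torch

tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

values, indices = torch.max(tensor, dim=1)

# gather() expects an index tensor with the same number of dimensions as
# the input, so add a column dimension, gather, then remove it again
picked = tensor.gather(1, indices.unsqueeze(1)).squeeze(1)
print(picked)  # tensor([31., 44.])

# dim=-1 names the last dimension, which is dim=1 for a 2D tensor
values_last, _ = torch.max(tensor, dim=-1)
print(torch.equal(values_last, values))  # True
```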

Using keepdim=True

Let’s pass the keepdim=True argument to keep the reduced dimension with size 1.

import torch

tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

max_values, max_indices = torch.max(tensor, dim=1, keepdim=True)

print("Max values:", max_values)
# Output:
# Max values: tensor([[31.],
#                     [44.]])

print("Max indices:", max_indices)
# Output:
# Max indices: tensor([[2],
#                      [0]])

Along dim=1 (rows), the tensor is reduced by comparing elements in each row:

  1. Row 1: [11.0, 12.0, 31.0] → max = 31.0 (index 2)
  2. Row 2: [44.0, 15.0, 1.0] → max = 44.0 (index 0)

With keepdim=True, max_values and max_indices have shape [2, 1] instead of [2].
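The [2, 1] shape is what makes keepdim=True convenient for broadcasting. A common pattern, sketched below, subtracts each row's maximum from that row (as is often done before a numerically stable softmax); the names row_max and shifted are illustrative.

```python
import torch

tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

# keepdim=True gives shape [2, 1]
row_max, _ = torch.max(tensor, dim=1, keepdim=True)

# [2, 3] minus [2, 1] broadcasts each row's maximum across that row
shifted = tensor - row_max
print(shifted.tolist())  # [[-20.0, -19.0, 0.0], [0.0, -29.0, -43.0]]

# Without keepdim the shape is [2]; subtracting it from a [2, 3] tensor
# would fail, because broadcasting aligns the trailing sizes 3 and 2
row_max_flat, _ = torch.max(tensor, dim=1)
print(row_max_flat.shape)  # torch.Size([2])
```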

Pairwise Maximum (Element-wise Comparison)

When you pass a second tensor instead of dim, torch.max() computes the element-wise maximum of the two tensors. It compares corresponding elements and returns a new tensor containing the larger value at each position.

import torch

# Create two 2x2 tensors
tensor1 = torch.tensor([[11.0, 21.0],
                        [41.0, 14.0]])
tensor2 = torch.tensor([[2.0, 11.0],
                        [14.0, 31.0]])

# Calculate element-wise maximum
maximum_from_two_tensors = torch.max(tensor1, tensor2)

print(maximum_from_two_tensors)
# Output: tensor([[11., 21.],
#                 [41., 31.]])

Each element of tensor1 is compared with the element at the same position in tensor2. At the first position, 11.0 vs 2.0 gives 11.0, so the first element of the output tensor is 11.0.

The same comparison is applied at every position, so the output tensor has the same shape as the input tensors ([2, 2]).
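The element-wise form behaves like torch.maximum(), which PyTorch provides as a dedicated function, and it supports broadcasting. In this sketch the 0-dimensional tensor floor (an illustrative name) is broadcast against tensor1, so every element below 15.0 is raised to 15.0. The floor value is wrapped in a tensor because this form expects other to be a Tensor.

```python
import torch

tensor1 = torch.tensor([[11.0, 21.0],
                        [41.0, 14.0]])
tensor2 = torch.tensor([[2.0, 11.0],
                        [14.0, 31.0]])

# The element-wise form gives the same result as torch.maximum()
same = torch.equal(torch.max(tensor1, tensor2), torch.maximum(tensor1, tensor2))
print(same)  # True

# A 0-dimensional tensor is broadcast to every position, acting as a floor
floor = torch.tensor(15.0)
floored = torch.max(tensor1, floor)
print(floored.tolist())  # [[15.0, 21.0], [41.0, 15.0]]
```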

Pre-allocated tensor

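To avoid allocating new result tensors on every call, you can pass the out argument, which writes the results into tensors you have already created. This is mainly useful when the same reduction runs repeatedly, for example inside a loop.

To create the pre-allocated tensors, use torch.empty() or torch.zeros(). When dim is given, out must be a tuple of two tensors: one for the maximum values and one for the indices, with dtype torch.long.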

import torch

tensor = torch.tensor([[11.0, 21.0, 31.0],
                       [21.0, 13.0, 26.0]])

pre_allocated_tensor = torch.empty(3)
pre_allocated_idx = torch.empty(3, dtype=torch.long)

torch.max(tensor, dim=0, out=(pre_allocated_tensor, pre_allocated_idx))

print(pre_allocated_tensor)
# Output: tensor([21., 21., 31.])

print(pre_allocated_idx)
# Output: tensor([1, 0, 0])
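The benefit of out shows up when the same buffers are reused. This sketch loops over a small, made-up list of 2x3 tensors and writes each column-wise result into the same col_max and col_idx buffers; the list seen only records the results so they can be inspected afterwards.

```python
import torch

# Made-up inputs, all with shape [2, 3]
batches = [
    torch.tensor([[11.0, 21.0, 31.0],
                  [21.0, 13.0, 26.0]]),
    torch.tensor([[5.0, 1.0, 8.0],
                  [6.0, 9.0, 2.0]]),
]

# Allocate the result buffers once, outside the loop
col_max = torch.empty(3)
col_idx = torch.empty(3, dtype=torch.long)

seen = []
for batch in batches:
    # Each call overwrites col_max and col_idx instead of allocating new tensors
    torch.max(batch, dim=0, out=(col_max, col_idx))
    seen.append((col_max.tolist(), col_idx.tolist()))

print(seen)
# [([21.0, 21.0, 31.0], [1, 0, 0]), ([6.0, 9.0, 8.0], [1, 1, 0])]
```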
That’s all!

Address: TwinStar, South Block – 1202, 150 Ft Ring Road, Nr. Nana Mauva Circle, Rajkot(360005), Gujarat, India

sprintchasetechnologies@gmail.com

(+91) 8000107255.


Copyright by @SprintChase All Rights Reserved
