The torch.max() function returns the maximum value(s) of a tensor, either across all elements or along a specified dimension. When a dimension is given, it also returns the indices of the maximum values.

If you call it without a dimension to find the global maximum, it returns a scalar (0-dimensional) tensor. To get a plain Python number out of that tensor, use the tensor.item() method.
If you specify a dimension, it returns a tuple (max_values, max_indices) holding the maximum values along that dimension and their positions.
Function syntax
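torch.max(input)  # global maximum
torch.max(input, dim, keepdim=False, *, out=None)  # maximum along a dimension
torch.max(input, other, *, out=None)  # element-wise maximum of two tensors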
Parameters
Argument | Description |
input (Tensor) | The input tensor in which to find the maximum value. |
dim (int or tuple of ints, optional) | The dimension or dimensions to reduce. If it is omitted, all dimensions are reduced. |
keepdim (bool, optional) | Whether to keep the reduced dimension with size 1. The default is False. |
out (tuple, optional) | A tuple of two pre-allocated tensors (max, max_indices) that receive the results. It must be passed as a keyword argument. |
other (Tensor, optional) | The second tensor for element-wise comparison. When it is given, dim and keepdim are not used. |
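As a quick orientation, the three call forms can be sketched side by side. The tensors `a` and `b` are illustrative values chosen for this sketch, not part of the examples that follow.

```python
import torch

# Illustrative inputs (assumed values for this sketch)
a = torch.tensor([[1.0, 5.0], [7.0, 3.0]])
b = torch.tensor([[4.0, 2.0], [6.0, 8.0]])

# Form 1: no dim, returns a single 0-dimensional tensor
global_max = torch.max(a)

# Form 2: with dim, returns a (values, indices) pair
values, indices = torch.max(a, dim=1)

# Form 3: with a second tensor, returns the element-wise maximum
pairwise = torch.max(a, b)

print(global_max.item())   # 7.0
print(values.tolist())     # [5.0, 7.0]
print(indices.tolist())    # [1, 0]
print(pairwise.tolist())   # [[4.0, 5.0], [7.0, 8.0]]
```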
Finding the global maximum
Let’s define a 1D tensor and find the maximum value from it.
import torch
tensor = torch.tensor([11, 21, 19, 46, 10])
max_value = torch.max(tensor)
print(max_value)         # Output: tensor(46)
print(max_value.dtype)   # Output: torch.int64
print(max_value.item())  # Output: 46
In the above code, the maximum value in the tensor is 46. The result is a scalar tensor, printed as tensor(46). If you want a plain Python number instead of a tensor, call tensor.item(), which returns 46.
Global Maximum within 2D Tensor

import torch
tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])
max_value = torch.max(tensor)
print(max_value)  # Output: tensor(44.)
In this case, the tensor is treated as if it were flattened to 1D, [11.0, 12.0, 31.0, 44.0, 15.0, 1.0], and the maximum of those values, 44.0, is returned as a scalar tensor.
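The flattening behaviour can be checked directly. The sketch below also uses torch.argmax, which is not covered above, to locate the maximum; the conversion from a flat index to (row, column) is done by hand with divmod and assumes a contiguous 2D layout.

```python
import torch

tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

# The global maximum matches the maximum of the flattened tensor
same = torch.equal(torch.max(tensor), torch.max(tensor.flatten()))

# torch.argmax without dim returns an index into the flattened tensor
flat_index = torch.argmax(tensor).item()

# Convert the flat index back to (row, column) using the number of columns
row, col = divmod(flat_index, tensor.shape[1])

print(same)        # True
print(flat_index)  # 3
print(row, col)    # 1 0
```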
Multiple Maxima

What if the maximum value appears more than once in a tensor? The returned value is the same either way. When indices are returned, the index of the first occurrence is reported.
import torch
tensor = torch.tensor([[11.0, 21.0],
                       [21.0, 13.0]])
max_value = tensor.max()
print(max_value)  # Output: tensor(21.)
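To see how ties affect the indices, here is a minimal sketch on a 1D tensor with a repeated maximum. The values are illustrative. The comparison with nonzero() is one possible way to list every tied position and is an addition for this sketch.

```python
import torch

# 7.0 appears at positions 1 and 2 (illustrative values)
tensor = torch.tensor([3.0, 7.0, 7.0, 1.0])

# With dim given, a single index is returned even when there is a tie
values, indices = torch.max(tensor, dim=0)
print(values.item())  # 7.0

# Whichever index is reported, it points at a maximal element
print(tensor[indices].item())  # 7.0

# To list every position that holds the maximum, compare against it explicitly
all_positions = (tensor == tensor.max()).nonzero(as_tuple=True)[0]
print(all_positions.tolist())  # [1, 2]
```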
Maximum along a dimension (with Indices)

Specifying dim reduces the tensor along that dimension, returning both the maximum values and their indices.
For example, if you pass dim=0, it reduces across the rows and finds the maximum of each column. If you have three columns, it returns a tensor with three values, one maximum per column.
If you pass dim=1, it reduces across the columns and finds the maximum of each row. If you have two rows, it returns a tensor with two values, one maximum per row.
import torch
tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])
# Max along dim=0 (columns)
values, indices = torch.max(tensor, dim=0)
print(values)   # Output: tensor([44., 15., 31.])
print(indices)  # Output: tensor([1, 1, 0])
# Max along dim=1 (rows)
values, indices = torch.max(tensor, dim=1)
print(values)   # Output: tensor([31., 44.])
print(indices)  # Output: tensor([2, 0])
Along dim=0 (columns), the tensor is reduced by comparing elements in each column:
- Column 1: [11.0, 44.0] → max = 44.0 (index 1)
- Column 2: [12.0, 15.0] → max = 15.0 (index 1)
- Column 3: [31.0, 1.0] → max = 31.0 (index 0)
Along dim=1 (rows), the tensor is reduced by comparing elements in each row:
- Row 1: [11.0, 12.0, 31.0] → max = 31.0 (index 2)
- Row 2: [44.0, 15.0, 1.0] → max = 44.0 (index 0)
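The returned pair is a named tuple, so its fields can also be read by name. The sketch below additionally uses Tensor.gather, which is not part of the original walkthrough, to confirm that the indices point back at the reported values.

```python
import torch

tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

result = torch.max(tensor, dim=1)

# Fields can be read by name instead of unpacking
print(result.values.tolist())   # [31.0, 44.0]
print(result.indices.tolist())  # [2, 0]

# Reducing dim=1 of a [2, 3] tensor leaves shape [2]
print(result.values.shape)  # torch.Size([2])

# Use the indices to look the values up again along dim=1
picked = tensor.gather(1, result.indices.unsqueeze(1)).squeeze(1)
print(torch.equal(picked, result.values))  # True
```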
Using keepdim=True
Let’s pass the keepdim=True argument to keep the reduced dimension with size 1.
import torch
tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])
max_values, max_indices = torch.max(tensor, dim=1, keepdim=True)
print("Max values:", max_values)
# Output:
# Max values: tensor([[31.],
#         [44.]])
print("Max indices:", max_indices)
# Output:
# Max indices: tensor([[2],
#         [0]])
Along dim=1 (rows), the tensor is reduced by comparing elements in each row:
- Row 1: [11.0, 12.0, 31.0] → max = 31.0 (index 2)
- Row 2: [44.0, 15.0, 1.0] → max = 44.0 (index 0)
With keepdim=True, max_values and max_indices have shape [2, 1] instead of [2].
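One common reason to keep the reduced dimension is broadcasting. As a minimal sketch, subtracting each row's maximum from that row works directly when the maxima have shape [2, 1]; with shape [2] the subtraction against a [2, 3] tensor would fail because the shapes do not broadcast.

```python
import torch

tensor = torch.tensor([[11.0, 12.0, 31.0],
                       [44.0, 15.0, 1.0]])

# Shape [2, 1] thanks to keepdim=True
max_values, _ = torch.max(tensor, dim=1, keepdim=True)

# [2, 3] - [2, 1] broadcasts across columns, so each row is shifted by its own maximum
shifted = tensor - max_values

print(max_values.shape)  # torch.Size([2, 1])
print(shifted.tolist())  # [[-20.0, -19.0, 0.0], [0.0, -29.0, -43.0]]
```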
Pairwise Maximum (Element-wise Comparison)
The torch.max() function can also compute the element-wise maximum of two tensors. It compares corresponding elements and returns a new tensor holding the larger value at each position.
import torch
# Create two 2x2 tensors
tensor1 = torch.tensor([[11.0, 21.0],
                        [41.0, 14.0]])
tensor2 = torch.tensor([[2.0, 11.0],
                        [14.0, 31.0]])
# Calculate element-wise maximum
maximum_from_two_tensors = torch.max(tensor1, tensor2)
print(maximum_from_two_tensors)
# Output:
# tensor([[11., 21.],
#         [41., 31.]])
The first element of tensor1 is compared with the first element of tensor2: 11.0 versus 2.0. The larger is 11.0, so the first element of the output tensor is 11.0.
The same comparison is made for every element of tensor1 and its counterpart in tensor2. The output tensor has the same shape as the input tensors ([2, 2]).
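Two related points can be sketched here: torch.maximum is the dedicated element-wise function and gives the same result as the two-tensor form of torch.max, and the second tensor may be broadcast against the first. The `threshold` value is an illustrative assumption.

```python
import torch

tensor1 = torch.tensor([[11.0, 21.0],
                        [41.0, 14.0]])
tensor2 = torch.tensor([[2.0, 11.0],
                        [14.0, 31.0]])

# torch.maximum matches torch.max called with two tensors
same = torch.equal(torch.max(tensor1, tensor2), torch.maximum(tensor1, tensor2))
print(same)  # True

# The second tensor can be broadcast, for example a 0-dimensional threshold
threshold = torch.tensor(20.0)
clipped = torch.max(tensor1, threshold)
print(clipped.tolist())  # [[20.0, 21.0], [41.0, 20.0]]
```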
Pre-allocated tensor
For efficiency, you can use the out argument to specify tensors that will store the results. This avoids allocating new result tensors on every call, which helps when the same buffers are reused.
To create the pre-allocated tensors, you can use torch.empty() or torch.zeros(). The tensor that receives the indices must have dtype torch.long.
import torch
tensor = torch.tensor([[11.0, 21.0, 31.0],
                       [21.0, 13.0, 26.0]])
pre_allocated_tensor = torch.empty(3)
pre_allocated_idx = torch.empty(3, dtype=torch.long)
torch.max(tensor, dim=0, out=(pre_allocated_tensor, pre_allocated_idx))
print(pre_allocated_tensor)  # Output: tensor([21., 21., 31.])
print(pre_allocated_idx)     # Output: tensor([1, 0, 0])
That’s all!