The torch.argwhere() method returns the indices of non-zero elements, providing a convenient way to locate elements that satisfy a condition.

In short, it returns a 2D tensor of shape (N, D), where:
- N is the number of non-zero elements.
- D is the number of dimensions in the input tensor.
Each row holds the full index of one non-zero element, with one coordinate per dimension.
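You can confirm the (N, D) shape by comparing it with torch.count_nonzero() and the input's ndim. The tensor values below are arbitrary illustrative data.

```python
import torch

# Arbitrary 3x4 example tensor with four non-zero entries
x = torch.tensor([[0, 3, 0, 0],
                  [5, 0, 0, 7],
                  [0, 0, 0, 9]])

indices = torch.argwhere(x)

n = int(torch.count_nonzero(x))  # N: number of non-zero elements
d = x.ndim                       # D: number of dimensions of the input

print(indices.shape)  # torch.Size([4, 2])
assert indices.shape == (n, d)
```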
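The torch.argwhere() method differs from torch.where(). When torch.where() is called with only a condition, it returns a tuple of D separate 1-D index tensors (one per dimension) instead of a single (N, D) tensor. When it is called as torch.where(condition, input, other), it selects values rather than returning indices. torch.argwhere(input) produces the same result as torch.nonzero(input).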
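The difference in return types is easiest to see side by side. This sketch compares torch.argwhere() with the single-argument form torch.where(condition) and with torch.nonzero(). Stacking the per-dimension tensors from torch.where() column-wise recovers the argwhere layout.

```python
import torch

x = torch.tensor([[0, 1],
                  [2, 0]])

a = torch.argwhere(x)             # single (N, D) tensor: one row per non-zero element
rows, cols = torch.where(x != 0)  # tuple of D tensors, each of length N
nz = torch.nonzero(x)             # same (N, D) layout as argwhere

# Stacking the per-dimension index tensors column-wise gives the argwhere layout
stacked = torch.stack((rows, cols), dim=1)

assert torch.equal(a, stacked)
assert torch.equal(a, nz)
```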
Syntax
torch.argwhere(input)
Parameters
Argument | Description |
---|---|
input (Tensor) | The input tensor, which can have any shape. The method finds the positions of its non-zero elements (or True values, for Boolean tensors). |
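Because input can have any shape, D always matches the input's dimensionality. For a 1-D tensor, each row holds a single index, so the result has shape (N, 1). This sketch also uses a float tensor to show that any non-zero value, including a negative one, is counted.

```python
import torch

# 1-D float tensor: negative values are non-zero too
v = torch.tensor([0.0, -1.5, 0.0, 2.0])
idx = torch.argwhere(v)

print(idx.shape)  # torch.Size([2, 1])
print(idx.dtype)  # torch.int64

# Flatten the (N, 1) result if plain 1-D positions are more convenient
positions = idx.flatten()
```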
Finding Non-Zero Elements in a Numerical Tensor
Let’s identify the indices of the non-zero elements in a 2D tensor. Any value in a 2D tensor (matrix) can be located by its row and column indices, both of which start at 0.
```python
import torch

tensor = torch.tensor([[0, 1], [2, 0]])
indices = torch.argwhere(tensor)
print(indices)
# Output:
# tensor([[0, 1],
#         [1, 0]])
```
Row | Column | Value | Is Non-Zero? |
---|---|---|---|
0 | 0 | 0 | No |
0 | 1 | 1 | Yes |
1 | 0 | 2 | Yes |
1 | 1 | 0 | No |
- At position [0, 1] → value 1
- At position [1, 0] → value 2
Hence, torch.argwhere(tensor) returns tensor([[0, 1], [1, 0]]).
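The walkthrough above pairs each index row with the value stored at that position. The same lookup can be written as a loop over the rows of the result. This is an illustrative sketch that is fine for tiny tensors but slow for large ones.

```python
import torch

tensor = torch.tensor([[0, 1], [2, 0]])
indices = torch.argwhere(tensor)

pairs = []
# Each row of `indices` is a (row, column) pair; look up the value stored there
for row, col in indices.tolist():
    value = tensor[row, col].item()
    pairs.append(((row, col), value))
    print(f"At position [{row}, {col}] -> value {value}")
```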
Working with Boolean Tensors
Locate True values in a Boolean tensor.
```python
import torch

bool_tensor = torch.tensor([[False, True, False], [True, False, True]])
indices = torch.argwhere(bool_tensor)
print(indices)
# Output:
# tensor([[0, 1],
#         [1, 0],
#         [1, 2]])
```
The True values sit at positions [0, 1], [1, 0], and [1, 2], so each of these positions becomes one row of the output.
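The conversion also works in reverse. Writing True into a blank mask at the returned positions reconstructs the original Boolean tensor. The sketch uses tuple(indices.T) to split the (N, D) result into one index tensor per dimension, which is the form PyTorch's advanced indexing accepts.

```python
import torch

bool_tensor = torch.tensor([[False, True, False],
                            [True, False, True]])
indices = torch.argwhere(bool_tensor)

# Start from an all-False mask with the same shape and dtype
rebuilt = torch.zeros_like(bool_tensor)
# indices.T has shape (D, N); unpacking it gives one index tensor per dimension
rebuilt[tuple(indices.T)] = True

assert torch.equal(rebuilt, bool_tensor)
```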
Higher-Dimensional Tensors
Let’s locate the non-zero elements in a 3D tensor. Each row of the result now has three entries, one index per dimension.
```python
import torch

tensor_3d = torch.tensor([[[0, 1], [0, 0]], [[21, 0], [13, 41]]])
indices = torch.argwhere(tensor_3d)
print(indices)
# Output:
# tensor([[0, 0, 1],
#         [1, 0, 0],
#         [1, 1, 0],
#         [1, 1, 1]])
```
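The tuple(indices.T) indexing trick works for any number of dimensions. A single indexing expression retrieves every non-zero value in the same order as the rows of indices. As a cross-check, the sketch compares the result with Boolean-mask indexing, which also yields the non-zero values.

```python
import torch

tensor_3d = torch.tensor([[[0, 1], [0, 0]],
                          [[21, 0], [13, 41]]])
indices = torch.argwhere(tensor_3d)

# indices.T has shape (3, N): one index tensor per dimension
values = tensor_3d[tuple(indices.T)]

# Boolean-mask indexing selects the same elements in the same order
assert torch.equal(values, tensor_3d[tensor_3d != 0])
print(values)
```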
Empty Tensor
Let’s look at the case where the tensor has no non-zero elements. For example, a tensor created with the torch.zeros() method contains only zeros.
In that case, the method returns an empty tensor of shape (0, D). N is 0, but D still equals the number of input dimensions.
```python
import torch

zero_tensor = torch.zeros(2, 3)
indices = torch.argwhere(zero_tensor)
print(indices)
# Output:
# tensor([], size=(0, 2), dtype=torch.int64)
```
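The empty result is still a valid tensor, not None. If you need to branch on whether anything was found, checking the number of rows is a simple guard. The helper name first_nonzero is hypothetical and was made up for this illustration.

```python
import torch

def first_nonzero(t: torch.Tensor):
    """Hypothetical helper: index tuple of the first non-zero element, or None."""
    indices = torch.argwhere(t)
    if indices.shape[0] == 0:  # N == 0: no non-zero elements
        return None
    return tuple(indices[0].tolist())

print(first_nonzero(torch.zeros(2, 3)))               # None
print(first_nonzero(torch.tensor([[0, 0], [0, 5]])))  # (1, 1)
```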
Using with Conditions
A comparison such as tensor > 10 produces a Boolean tensor, so you can pass a condition directly to this method to find the indices of elements that meet specific criteria.
```python
import torch

tensor = torch.tensor([[11, 2, 3], [4, 5, 61]])
indices = torch.argwhere(tensor > 10)
print(indices)
# Output:
# tensor([[0, 0],
#         [1, 2]])
```
In the above code, tensor > 10 evaluates to a Boolean tensor that is True only at [0, 0] (value 11) and [1, 2] (value 61), so argwhere returns those two positions.
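Conditions can be combined with the element-wise operators & (and), | (or), and ~ (not) before the result is passed to argwhere. Each comparison needs its own parentheses because & binds more tightly than > and < in Python. This sketch reuses the tensor from the previous example.

```python
import torch

tensor = torch.tensor([[11, 2, 3], [4, 5, 61]])

# Values strictly between 3 and 60
indices = torch.argwhere((tensor > 3) & (tensor < 60))
print(indices)
```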
GPU Compatibility
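This method also works on CUDA tensors, and the returned index tensor is on the same device as the input. According to the PyTorch documentation, calling it on a CUDA tensor causes host-device synchronization, because the number of rows N is not known until the non-zero elements have been counted.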
```python
import torch

# Requires a CUDA-capable GPU and a CUDA-enabled PyTorch build
tensor = torch.tensor([[0, 1, 0], [2, 0, 3]], device='cuda')
indices = torch.argwhere(tensor)
print(indices)
```
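A device-agnostic variant picks CUDA when it is available and falls back to the CPU otherwise, so the same script runs on machines without a GPU. The result stays on the input's device. Calling .cpu() before .tolist() makes the device-to-host copy explicit when you need plain Python lists, for example for logging.

```python
import torch

# Fall back to the CPU when no CUDA device is available
device = "cuda" if torch.cuda.is_available() else "cpu"

tensor = torch.tensor([[0, 1, 0], [2, 0, 3]], device=device)
indices = torch.argwhere(tensor)

# The result lives on the same device as the input
assert indices.device == tensor.device

# Copy to host memory, then convert to nested Python lists
as_list = indices.cpu().tolist()
print(as_list)  # [[0, 1], [1, 0], [1, 2]]
```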
