The torch.where() function either returns the indices of the elements that satisfy a condition or builds a new tensor by choosing elements from two inputs according to a condition mask, depending on the arguments you provide.
The definition sounds complex, but it is straightforward. The function has two modes:
- Index identification: Called with only a condition, it returns the indices of the elements where the condition evaluates to True.
- Element selection: Called with a condition and two inputs, it returns a new tensor whose elements are taken from one of the two inputs based on the condition mask.
This function is useful for masking, thresholding, conditional calculations, and data cleaning.
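As a quick illustration of the data-cleaning use case, here is a minimal sketch that replaces NaN entries with zero. The tensor values and the variable names (`readings`, `cleaned`) are made up for this example.

```python
import torch

# Hypothetical sensor readings with missing values encoded as NaN.
readings = torch.tensor([1.5, float("nan"), 3.0, float("nan")])

# Where a reading is NaN, take 0.0; otherwise keep the original reading.
cleaned = torch.where(torch.isnan(readings), torch.zeros_like(readings), readings)
print(cleaned.tolist())  # [1.5, 0.0, 3.0, 0.0]
```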
Function signature
# (1) Condition-only overload:
torch.where(condition: Tensor)
# (2) Full overload:
torch.where(condition: Tensor,
            x: Tensor or scalar,
            y: Tensor or scalar)
Parameters
| Argument | Description |
| --- | --- |
| condition (BoolTensor) | A boolean tensor that determines, element by element, whether to select from x or y. |
| x, y (Tensor or Scalar) | The tensors or scalars from which the output is selected. Where the condition is True, the element comes from x; where it is False, it comes from y. condition, x, and y must be broadcastable to a common shape. |
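To make the parameter roles concrete, here is a small sketch that passes a tensor for x and a Python scalar for y. The values and names (`scores`, `result`) are illustrative, and the example assumes a recent PyTorch release that accepts scalar arguments, as the signature above describes.

```python
import torch

scores = torch.tensor([0.2, 0.9, 0.5, 0.7])

# A comparison produces the boolean condition tensor.
condition = scores >= 0.5
print(condition.dtype)  # torch.bool

# x is a tensor, y is a Python scalar: scores below 0.5 become -1.0.
result = torch.where(condition, scores, -1.0)
print(result)
```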
Index identification
Let’s define a boolean mask: a 2×3 tensor of dtype torch.bool.
import torch
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
idx = torch.where(mask)
print(idx)
# Output: (tensor([0, 0, 1]), tensor([0, 2, 1]))
How do we interpret the output? You can see that the output is a tuple of 1D index tensors, one for each dimension, identifying all positions where the condition is True.
That means it returns coordinates of elements that have True values.
The mask tensor is 2×3, meaning it has two rows and three columns.
The output has two 1D tensors.
- [0, 0, 1]: the first output tensor holds the row indices (zero-based).
- [0, 2, 1]: the second output tensor holds the column indices (zero-based).
To locate elements at a specific position, we need to combine the row and column indices of the respective tensors.
Let’s combine the output and find the True values. We now have indices for rows and columns for True values.
The first element of the first 1D output tensor is 0.
The first element of the second 1D tensor is also 0, which means (0, 0), and it is the coordinate of the first True value.
If you check the mask tensor, the first element is True, and its coordinates are (0, 0), which corresponds to the 0th row and 0th column.
The second elements of the two tensors form (0, 2), so the second True value is at row 0, column 2.
The third elements form (1, 1), so the third True value is at row 1, column 1, that is, the second row and second column.
To summarize this, the output tuple means there are three True entries, located at:
- (row=0, col=0)
- (row=0, col=2)
- (row=1, col=1)
You can see that the where() method quickly found the coordinates of nonzero or True elements.
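The index tuple is also useful for reading values back out of a tensor. The sketch below uses a made-up `data` tensor, indexes it with the returned tuple, and checks that the condition-only form matches torch.nonzero() called with as_tuple=True.

```python
import torch

data = torch.tensor([[10, -3, 25],
                     [-7, 40, -1]])

# Positions of the positive entries, as (row_indices, col_indices).
rows, cols = torch.where(data > 0)

# The index tensors can index the original tensor to pull out the matching values.
positives = data[rows, cols]
print(positives.tolist())  # [10, 25, 40]

# The condition-only form matches torch.nonzero(..., as_tuple=True).
same_rows, same_cols = torch.nonzero(data > 0, as_tuple=True)
print(torch.equal(rows, same_rows) and torch.equal(cols, same_cols))  # True
```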
If you need a single tensor of shape (N, D) with all indices, where N is the number of True elements and D is the number of dimensions, you can stack them with the torch.stack() function.
import torch
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
coords = torch.stack(torch.where(mask), dim=1)
print(coords)
# Output:
# tensor([[0, 0],
#         [0, 2],
#         [1, 1]])
Now we have a single tensor containing the coordinates of all the True values.
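As a cross-check, torch.nonzero() without as_tuple returns the same (N, D) layout directly. The sketch below compares the two results and converts the coordinates to a plain Python list; the names `stacked`, `direct`, and `pairs` are illustrative.

```python
import torch

mask = torch.tensor([[True, False, True],
                     [False, True, False]])

stacked = torch.stack(torch.where(mask), dim=1)

# torch.nonzero without as_tuple returns the (N, D) layout directly.
direct = torch.nonzero(mask)
print(torch.equal(stacked, direct))  # True

# Each row of the result is one coordinate pair.
pairs = [tuple(row) for row in stacked.tolist()]
print(pairs)  # [(0, 0), (0, 2), (1, 1)]
```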
Element-wise selection
Let’s replace negative values with zero in the tensor.
import torch
tensor = torch.tensor([-11, 21, -18, 19, -6])
output = torch.where(tensor > 0, tensor, torch.zeros_like(tensor))
print(output)
# Output: tensor([ 0, 21,  0, 19,  0])
In this code, the comparison tensor > 0 produces a boolean mask that is True for positive values and False otherwise.
torch.where() keeps the original value wherever the mask is True and takes the corresponding element of torch.zeros_like(tensor), which is 0, wherever the mask is False.
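Element-wise selection also works between two unrelated tensors. As a minimal sketch with made-up tensors `a` and `b`, the following picks the larger value at each position and confirms that the result matches torch.maximum().

```python
import torch

a = torch.tensor([3, 8, 1, 6])
b = torch.tensor([5, 2, 9, 6])

# Pick from a where a is larger, otherwise pick from b.
larger = torch.where(a > b, a, b)
print(larger.tolist())  # [5, 8, 9, 6]

# Same result as the built-in element-wise maximum.
print(torch.equal(larger, torch.maximum(a, b)))  # True
```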
Broadcasting
import torch
condition = torch.tensor([[True, False], [False, True]])
result = torch.where(condition, 1.0, 0.0)
print(result)
# Output:
# tensor([[1., 0.],
#         [0., 1.]])
In this code, the condition is a boolean mask tensor, and x and y are the scalars 1.0 and 0.0.
The output is 1.0 (the value of x) wherever the condition is True.
The output is 0.0 (the value of y) wherever the condition is False.
If the condition had a different shape, say (2, 1), broadcasting would automatically expand the scalars 1.0 and 0.0 to match that shape during the operation.
import torch
condition = torch.tensor([[True, False], [False, True]])
result_of_broadcasting = torch.where(condition[:, :1], 1.0, 0.0)
print(result_of_broadcasting)
# Output:
# tensor([[1.],
#         [0.]])
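Broadcasting applies between the condition and tensor inputs as well, not only scalars. In this sketch, a (2, 1) condition is combined with a (3,) tensor to produce a (2, 3) result; the names `row_flags` and `values` are illustrative, and the scalar 0.0 again assumes a PyTorch release that accepts scalar arguments.

```python
import torch

# One flag per row, shape (2, 1).
row_flags = torch.tensor([[True], [False]])

# One value per column, shape (3,).
values = torch.tensor([10.0, 20.0, 30.0])

# (2, 1) and (3,) broadcast to (2, 3); the scalar fills rows whose flag is False.
result = torch.where(row_flags, values, 0.0)
print(result.shape)     # torch.Size([2, 3])
print(result.tolist())  # [[10.0, 20.0, 30.0], [0.0, 0.0, 0.0]]
```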
Multiple conditions
What if we need to apply several conditions to a tensor? In that case, we can nest torch.where() calls to handle each condition in turn.
Let’s say we have a tensor of numbers and three rules:
- If a number is greater than 10, keep it.
- If a number is between 5 and 10 (inclusive), set it to 0.
- If a number is less than 5, set it to -1.
import torch
tensor = torch.tensor([4, 15, 7, 21, 10])
# Condition 1: x < 5 → set to -1
# Condition 2: 5 <= x <= 10 → set to 0
# Else (x > 10) → keep original
filtered_tensor = torch.where(tensor < 5, torch.tensor(-1),
                              torch.where(tensor <= 10, torch.tensor(0), tensor))
print("Original tensor:", tensor)
print("Filtered tensor:", filtered_tensor)
# Output:
# Original tensor: tensor([ 4, 15,  7, 21, 10])
# Filtered tensor: tensor([-1, 15,  0, 21,  0])
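Nesting is not the only way to express these rules. As an alternative sketch, each rule can be written as its own boolean mask (combined with & where needed) and applied in sequence. The names `below`, `middle`, and `result` are illustrative.

```python
import torch

tensor = torch.tensor([4, 15, 7, 21, 10])

# Build each condition as its own boolean mask.
below = tensor < 5
middle = (tensor >= 5) & (tensor <= 10)

# Start from the original values, then overwrite each band in turn.
result = torch.where(below, torch.tensor(-1), tensor)
result = torch.where(middle, torch.tensor(0), result)
print(result.tolist())  # [-1, 15, 0, 21, 0]
```

Writing the masks separately can be easier to read when there are many rules, because each condition is named and no call is nested inside another.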
That’s all!
