The torch.masked_select() method selects elements from an input tensor according to a boolean mask. Wherever the mask is True, the corresponding element of the input tensor is included in the output tensor.

The output tensor therefore has as many elements as there are True values in the mask (after broadcasting).
It returns a new 1D tensor containing the selected elements in row-major order. The shapes of the mask tensor and the input tensor don't need to match, but they must be broadcastable.
The most common use case is filtering values from a tensor based on a condition, because comparison operators such as > and == return boolean tensors that can be used directly as masks.
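A short sanity check of these properties, using an arbitrary example tensor: the result is 1D, its length equals the number of True values in the mask, and, when the shapes match, it is identical to boolean indexing with input[mask].

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
mask = x % 2 == 0  # True for even values

selected = torch.masked_select(x, mask)
print(selected)  # tensor([2, 4, 6])

# The output is always 1D, and its length equals the number of True values.
assert selected.dim() == 1
assert selected.numel() == int(mask.sum())

# With matching shapes, masked_select gives the same result as boolean indexing.
assert torch.equal(selected, x[mask])
```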
Syntax
torch.masked_select(input, mask, *, out=None)
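Two details the signature implies, sketched with made-up values: the same operation is also available as the Tensor method x.masked_select(mask), and out is keyword-only (it follows the `*`), so passing it positionally raises a TypeError.

```python
import torch

x = torch.tensor([3, 8, 1, 9])
mask = x > 2

via_function = torch.masked_select(x, mask)  # function form
via_method = x.masked_select(mask)           # equivalent Tensor method
print(via_function)  # tensor([3, 8, 9])

# An empty buffer is resized to fit the result.
out_buffer = torch.empty(0, dtype=x.dtype)
torch.masked_select(x, mask, out=out_buffer)  # OK: out passed by keyword

positional_out_rejected = False
try:
    torch.masked_select(x, mask, out_buffer)  # out passed positionally
except TypeError:
    positional_out_rejected = True
print(positional_out_rejected)  # True
```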
Parameters
| Argument | Description |
|---|---|
| input (Tensor) | The input tensor from which the elements are selected. |
| mask (BoolTensor) | A boolean tensor with the same shape as input, or a shape broadcastable with it. |
| out (Tensor, optional) | The output tensor in which to store the result. It must be passed as a keyword argument. |
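The mask needs to be a genuine boolean tensor. In the sketch below (values chosen arbitrarily), an integer tensor of 0s and 1s is rejected with a RuntimeError, while converting it with .bool() produces a valid mask.

```python
import torch

x = torch.tensor([10, 20, 30])
int_mask = torch.tensor([1, 0, 1])  # 0/1 integers (dtype int64), not booleans

int_mask_rejected = False
try:
    torch.masked_select(x, int_mask)
except RuntimeError:
    int_mask_rejected = True
print(int_mask_rejected)  # True

# Converting the 0/1 tensor to bool gives a valid mask.
fixed = torch.masked_select(x, int_mask.bool())
print(fixed)  # tensor([10, 30])
```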
Element selection using a Boolean mask
Let's define a 1D tensor and select the elements where the mask is True.

```python
import torch

input_tensor = torch.tensor([11, 12, 31, 41, 15])
mask = torch.tensor([True, False, True, False, True])

# Keep only the elements whose mask entry is True
filtered_tensor = torch.masked_select(input_tensor, mask)
print(filtered_tensor)
# Output: tensor([11, 31, 15])
```
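To make the selection rule concrete, here is a plain-Python reference version for the 1D, same-shape case. The helper masked_select_reference is hypothetical, written only for illustration; it is not part of PyTorch and ignores broadcasting.

```python
import torch

def masked_select_reference(values, mask):
    """Hypothetical helper: mimics torch.masked_select for 1D tensors of equal shape."""
    # Walk both tensors in order and keep each value whose mask entry is True.
    kept = [v for v, keep in zip(values.tolist(), mask.tolist()) if keep]
    return torch.tensor(kept, dtype=values.dtype)

input_tensor = torch.tensor([11, 12, 31, 41, 15])
mask = torch.tensor([True, False, True, False, True])

reference = masked_select_reference(input_tensor, mask)
print(reference)  # tensor([11, 31, 15])
print(torch.equal(reference, torch.masked_select(input_tensor, mask)))  # True
```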
Filtering with a condition

Now let's build the mask from a condition. Suppose we want to keep only the values greater than 50.
```python
import torch

input_2d_tensor = torch.tensor([[11, 52, 13], [41, 51, 61]])

# The comparison returns a boolean tensor of the same shape
mask = input_2d_tensor > 50

filtered_2d_tensor = torch.masked_select(input_2d_tensor, mask)
print(filtered_2d_tensor)
# Output: tensor([52, 51, 61])
```
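Conditions can be combined with the element-wise operators & (and) and | (or). A minimal sketch reusing the tensor above; the thresholds 20, 55, 12 and 60 are arbitrary:

```python
import torch

input_2d_tensor = torch.tensor([[11, 52, 13], [41, 51, 61]])

# Parentheses are required: in Python, & and | bind more tightly than > and <.
between = torch.masked_select(input_2d_tensor, (input_2d_tensor > 20) & (input_2d_tensor < 55))
outside = torch.masked_select(input_2d_tensor, (input_2d_tensor < 12) | (input_2d_tensor > 60))

print(between)  # tensor([52, 41, 51])
print(outside)  # tensor([11, 61])
```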
Broadcasting with different shapes
If the mask's shape differs from the input's but the two are broadcastable, no error is raised: the mask is expanded to the input's shape before selection. Here, the (2, 1) mask selects the entire first row and nothing from the second.

```python
import torch

input_2d_tensor = torch.tensor([[11, 52, 13], [41, 51, 61]])
mask = torch.tensor([[True], [False]])  # Shape (2, 1), broadcast to (2, 3)

filtered_broadcasting = torch.masked_select(input_2d_tensor, mask)
print(filtered_broadcasting)
# Output: tensor([11, 52, 13])
```
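A few more broadcasting cases, assuming the same (2, 3) tensor: a 1D mask of length 3 applies to every row, the input itself can be expanded to the mask's shape (so the output may contain more elements than the input), and shapes that cannot be broadcast raise a RuntimeError.

```python
import torch

input_2d_tensor = torch.tensor([[11, 52, 13], [41, 51, 61]])  # shape (2, 3)

# A shape (3,) mask is broadcast across both rows, selecting columns 0 and 2.
column_mask = torch.tensor([True, False, True])
by_column = torch.masked_select(input_2d_tensor, column_mask)
print(by_column)  # tensor([11, 13, 41, 61])

# Broadcasting works in both directions: the 1D input is expanded to the 2D mask's shape.
row = torch.tensor([1, 2, 3])
mask_2d = torch.tensor([[True, False, True], [True, True, False]])
expanded_input = torch.masked_select(row, mask_2d)
print(expanded_input)  # tensor([1, 3, 1, 2])

# Shapes (2, 3) and (3, 1) are not broadcastable, so PyTorch raises a RuntimeError.
bad_mask = torch.tensor([[True], [False], [True]])
bad_mask_rejected = False
try:
    torch.masked_select(input_2d_tensor, bad_mask)
except RuntimeError:
    bad_mask_rejected = True
print(bad_mask_rejected)  # True
```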
Empty mask

What if the mask contains only False values? Does the method return any elements?
Since the mask contains no True values, it returns an empty tensor with the same dtype as the input.
```python
import torch

input_tensor = torch.tensor([2, 3, 5, 6])
mask = torch.tensor([False, False, False, False])

empty_result = torch.masked_select(input_tensor, mask)
print(empty_result)
# Output: tensor([], dtype=torch.int64)
```
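Because an empty result is still a valid tensor, downstream code should check for it explicitly before reducing over it. A minimal sketch; the threshold 100 and the summary messages are illustrative:

```python
import torch

input_tensor = torch.tensor([2, 3, 5, 6])
mask = input_tensor > 100  # no element satisfies this condition

result = torch.masked_select(input_tensor, mask)
print(result.shape, result.dtype)  # torch.Size([0]) torch.int64

# Guard downstream reductions: a mean over zero elements is not meaningful.
if result.numel() == 0:
    summary = "No elements selected"
else:
    summary = f"Mean of selected values: {result.float().mean().item():.2f}"
print(summary)  # No elements selected
```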
Using the out Parameter
If you have a pre-allocated tensor, you can pass it as the keyword argument out to store the result of masked_select() in it.
```python
import torch

input_tensor = torch.tensor([1, 21, 3, 41])
mask = torch.tensor([True, False, True, False])

# Pre-allocate a buffer with the same dtype and the expected number of elements
out_tensor = torch.empty(2, dtype=input_tensor.dtype)

result = torch.masked_select(input_tensor, mask, out=out_tensor)
print(result)
# Output: tensor([1, 3])
```
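The result is written into out_tensor, which must have the same dtype as the input tensor. If out_tensor does not have the required size, PyTorch resizes it rather than raising an error, but resizing an out tensor that already holds elements is deprecated and emits a warning, so it is best to pre-allocate exactly as many elements as the mask has True values.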
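One way to size the buffer exactly is to count the True values in the mask with mask.sum(). The sketch below does that with arbitrary values and also checks the dtype requirement: an out tensor of a different dtype is rejected.

```python
import torch

input_tensor = torch.tensor([1, 21, 3, 41])
mask = input_tensor < 10

# Size the buffer from the mask so it matches the result exactly (no resize needed).
out_tensor = torch.empty(int(mask.sum()), dtype=input_tensor.dtype)
torch.masked_select(input_tensor, mask, out=out_tensor)
print(out_tensor)  # tensor([1, 3])

# An out tensor whose dtype differs from the input's is rejected.
wrong_dtype = torch.empty(2, dtype=torch.float32)
dtype_mismatch_rejected = False
try:
    torch.masked_select(input_tensor, mask, out=wrong_dtype)
except RuntimeError:
    dtype_mismatch_rejected = True
print(dtype_mismatch_rejected)  # True
```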
Filter out invalid values
Suppose our input tensor contains NaN values and we want to remove them. We can build a mask with torch.isnan(), invert it with the ~ operator so that it is True for every non-NaN element, and pass it to masked_select().
```python
import torch

data = torch.tensor([21.0, float('nan'), 19.0, float('nan')])

# True where the value is NOT NaN
mask = ~torch.isnan(data)

clean_data = torch.masked_select(data, mask)
print(clean_data)
# Output: tensor([21., 19.])
```

That's it!
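Note that torch.isnan() only detects NaN. If infinities should be dropped as well, torch.isfinite() gives a mask that is False for NaN, +inf, and -inf alike. A sketch with made-up data:

```python
import torch

data = torch.tensor([21.0, float('nan'), float('inf'), -float('inf'), 19.0])

# ~torch.isnan() removes NaN but keeps +inf and -inf.
not_nan = torch.masked_select(data, ~torch.isnan(data))

# torch.isfinite() is False for NaN, +inf, and -inf.
finite_only = torch.masked_select(data, torch.isfinite(data))

print(not_nan.numel())  # 4
print(finite_only)      # tensor([21., 19.])
```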