Sprint Chase Technologies
  • Home
  • About
    • Why Choose Us
    • Contact Us
    • Team Members
    • Testimonials
  • Services
    • Web Development
    • Web Application Development
    • Mobile Application Development
    • Web Design
    • UI/UX Design
    • Social Media Marketing
    • Projects
  • Blog
    • PyTorch
    • Python
    • JavaScript
  • IT Institute

Need Help? Talk to an Expert

+91 8000107255

torch.masked_select() Method in PyTorch

  • Written by krunallathiya21
  • June 13, 2025
PyTorch

The torch.masked_select() method selects elements from the input tensor based on a boolean mask. Wherever the mask is True, the corresponding element of the input tensor is included in the output tensor.

So, the number of elements in the output tensor equals the number of True values in the mask.
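A quick way to confirm this relationship is to compare the length of the result with the number of True entries in the mask. The example data below is chosen only for illustration.

```python
import torch

# Illustrative data: the integers 0..9
x = torch.arange(10)
mask = x % 3 == 0             # True at 0, 3, 6 and 9

selected = torch.masked_select(x, mask)

# The output length equals the number of True entries in the mask
num_true = int(mask.sum())
print(selected)                      # tensor([0, 3, 6, 9])
print(selected.numel() == num_true)  # True
```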

It returns a 1D tensor containing the selected elements. The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.
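The output is always 1D, no matter how many dimensions the input has. The PyTorch documentation also notes that the returned tensor does not share storage with the original, so modifying the result leaves the input unchanged. A small sketch with an arbitrarily chosen 3D input:

```python
import torch

# Illustrative 3D input of shape (2, 2, 3)
x = torch.arange(12).reshape(2, 2, 3)
mask = x > 7                  # True for 8, 9, 10 and 11

result = torch.masked_select(x, mask)
print(result.shape)           # torch.Size([4]): flattened to 1D
print(result.dim())           # 1

# The result is a copy, not a view: writing to it does not touch x
result[0] = -1
print(x.max().item())         # 11
```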

The most common use case is filtering values from a tensor based on a condition, because comparison operations such as x > 50 return boolean tensors that can be passed directly as the mask.
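Because comparisons produce boolean tensors, several conditions can be combined element-wise with & (and), | (or) and ~ (not) before being used as a mask. The scores below are made-up values for illustration; note the parentheses, which are required because & binds more tightly than the comparison operators.

```python
import torch

scores = torch.tensor([12, 45, 67, 89, 30, 55])

# Keep values in the closed range [40, 70]
in_range = (scores >= 40) & (scores <= 70)

print(in_range.dtype)                         # torch.bool
print(torch.masked_select(scores, in_range))  # tensor([45, 67, 55])
```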

Syntax

torch.masked_select(input, mask, *, out=None)
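The same operation is also available as a Tensor method, and out is keyword-only, so it has to be passed by name. A minimal sketch comparing the forms, with a pre-allocated buffer that already has the right size and dtype:

```python
import torch

x = torch.tensor([5, 10, 15, 20])
mask = x > 8

# Function form
a = torch.masked_select(x, mask)

# Equivalent Tensor-method form
b = x.masked_select(mask)

# out is passed by name; this buffer already holds exactly 3 elements
buf = torch.empty(3, dtype=x.dtype)
torch.masked_select(x, mask, out=buf)

print(torch.equal(a, b), torch.equal(a, buf))  # True True
```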

Parameters

Argument                  Description
input (Tensor)            The input tensor from which elements are selected.
mask (BoolTensor)         A boolean tensor with the same shape as input, or a shape broadcastable to it.
out (Tensor, optional)    A tensor in which to store the result (keyword-only).
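Since the mask is expected to be a boolean tensor, a 0/1 mask stored as integers (for example, one loaded from a file) should be converted with .bool() first. The data here is hypothetical.

```python
import torch

x = torch.tensor([17, 18, 19, 20])

# Hypothetical 0/1 mask stored as integers
int_mask = torch.tensor([1, 0, 0, 1])

# Convert to bool: nonzero -> True, zero -> False
bool_mask = int_mask.bool()

selected = torch.masked_select(x, bool_mask)
print(selected)  # tensor([17, 20])
```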

Element selection using a Boolean mask

Let’s define a 1D tensor and select elements from it where the mask is True.
import torch

input_tensor = torch.tensor([11, 12, 31, 41, 15])

mask = torch.tensor([True, False, True, False, True])

filtered_tensor = torch.masked_select(input_tensor, mask)

print(filtered_tensor)

# Output: tensor([11, 31, 15])
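When the mask has exactly the same shape as the input, boolean indexing (input_tensor[mask]) produces the same 1D result. If you also need the positions of the selected elements, mask.nonzero() returns them. A short check:

```python
import torch

input_tensor = torch.tensor([11, 12, 31, 41, 15])
mask = torch.tensor([True, False, True, False, True])

via_indexing = input_tensor[mask]
via_masked_select = torch.masked_select(input_tensor, mask)
print(torch.equal(via_indexing, via_masked_select))  # True

# Indices of the True entries; squeeze(1) turns shape (3, 1) into (3,)
positions = mask.nonzero().squeeze(1)
print(positions)  # tensor([0, 2, 4])
```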

Filtering with a condition

Let’s build the mask from a condition. Suppose we want the output tensor to contain only the values greater than 50.

import torch

input_2d_tensor = torch.tensor([[11, 52, 13], [41, 51, 61]])

mask = input_2d_tensor > 50

filtered_2d_tensor = torch.masked_select(input_2d_tensor, mask)

print(filtered_2d_tensor)

# Output: tensor([52, 51, 61])
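For a 2D input, the selected values come out in row-major order, which is why 52 (row 0) appears before 51 and 61 (row 1). One way to see this is to compare against selecting from the flattened tensor:

```python
import torch

input_2d_tensor = torch.tensor([[11, 52, 13], [41, 51, 61]])
mask = input_2d_tensor > 50

selected = torch.masked_select(input_2d_tensor, mask)

# Flatten both tensors (row-major) and index with the flattened mask
flat_version = input_2d_tensor.flatten()[mask.flatten()]
print(torch.equal(flat_version, selected))  # True
```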

Broadcasting with different shapes

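The mask’s shape does not have to match the input’s shape. As long as the two shapes are broadcastable, masked_select() does not raise an error. Here, a mask of shape (2, 1) is broadcast across the three columns, so it selects the whole first row and none of the second.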
import torch

input_2d_tensor = torch.tensor([[11, 52, 13], [41, 51, 61]])

mask = torch.tensor([[True], [False]])  # Shape (2, 1)

filtered_broadcasting = torch.masked_select(input_2d_tensor, mask)

print(filtered_broadcasting)

# Output: tensor([11, 52, 13])
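Broadcasting also works along the other axis: a mask of shape (3,) is applied to every row, selecting columns instead. A mask whose shape cannot be broadcast, such as (2,) against (2, 3), raises a RuntimeError. A sketch of both cases:

```python
import torch

input_2d_tensor = torch.tensor([[11, 52, 13], [41, 51, 61]])

# Shape (3,) broadcasts across both rows: keeps columns 0 and 2 of each row
col_mask = torch.tensor([True, False, True])
col_selected = torch.masked_select(input_2d_tensor, col_mask)
print(col_selected)  # tensor([11, 13, 41, 61])

# Shape (2,) vs (2, 3): trailing sizes 2 and 3 differ, so it cannot broadcast
bad_mask = torch.tensor([True, False])
broadcast_failed = False
try:
    torch.masked_select(input_2d_tensor, bad_mask)
except RuntimeError as err:
    broadcast_failed = True
    print("Not broadcastable:", err)
```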

Empty mask

What if the mask contains only False values? Does the method return any elements?

Since the mask contains no True values, it will return an empty tensor.

import torch

input_tensor = torch.tensor([2, 3, 5, 6])

mask = torch.tensor([False, False, False, False])

empty_result = torch.masked_select(input_tensor, mask)

print(empty_result)

# Output: tensor([], dtype=torch.int64)
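In practice, an all-False mask usually comes from a condition that nothing satisfies. The empty result still has the input’s dtype and a shape of (0,), so checking numel() is a simple guard before reducing it. The threshold of 100 below is arbitrary.

```python
import torch

input_tensor = torch.tensor([2, 3, 5, 6])
mask = input_tensor > 100     # no element satisfies this, so all False

result = torch.masked_select(input_tensor, mask)

print(result.shape)   # torch.Size([0])
print(result.dtype)   # torch.int64, same as the input

# Guard before using the result, e.g. before a reduction
if result.numel() == 0:
    print("No elements matched the condition")
```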

Using the out Parameter

If you have a pre-allocated tensor, you can pass it through the keyword-only out argument to store the result of the masked_select() method.

import torch

input_tensor = torch.tensor([1, 21, 3, 41])

mask = torch.tensor([True, False, True, False])

out_tensor = torch.empty(2, dtype=input_tensor.dtype)

result = torch.masked_select(input_tensor, mask, out=out_tensor)

print(result)

# Output: tensor([1, 3])

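The result is written into out_tensor, and the returned tensor is out_tensor itself. Its dtype must match the input’s dtype, but its size does not have to be exact: in recent PyTorch releases, an out tensor of the wrong size is resized to fit, with a UserWarning if it already contained elements.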
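A minimal sketch, assuming a recent PyTorch release: starting from a zero-element buffer with the matching dtype lets PyTorch resize it to the required length without a warning, and the returned tensor points at the same storage as the buffer.

```python
import torch

input_tensor = torch.tensor([1, 21, 3, 41])
mask = torch.tensor([True, False, True, False])

# Zero-element buffer with the input's dtype; it is resized to hold the result
out_tensor = torch.empty(0, dtype=input_tensor.dtype)
result = torch.masked_select(input_tensor, mask, out=out_tensor)

print(out_tensor.shape)                            # torch.Size([2])
print(result.data_ptr() == out_tensor.data_ptr())  # True: same storage
```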

Filter out invalid values

Suppose the input tensor contains NaN values that we want to remove. We can combine torch.isnan() with masked_select(): torch.isnan() returns True at every NaN position, so inverting it with ~ gives a mask that is True only for the valid values.

import torch

data = torch.tensor([21.0, float('nan'), 19.0, float('nan')])

mask = ~torch.isnan(data)

clean_data = torch.masked_select(data, mask)

print(clean_data)

# Output: tensor([21., 19.])
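If infinities should be dropped as well, torch.isfinite() can replace ~torch.isnan(), since it is False for NaN, +inf and -inf. The cleaned tensor can then be reduced safely; the data below is illustrative.

```python
import torch

data = torch.tensor([21.0, float('nan'), float('inf'), 19.0, float('-inf')])

# isfinite keeps only ordinary numbers, dropping NaN and both infinities
clean_data = torch.masked_select(data, torch.isfinite(data))

print(clean_data)         # tensor([21., 19.])
print(clean_data.mean())  # tensor(20.)
```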
That’s it!

Address:  TwinStar, South Block – 1202, 150 Ft Ring Road, Nr. Nana Mauva Circle, Rajkot(360005), Gujarat, India

sprintchasetechnologies@gmail.com

(+91) 8000107255.

ABOUT US
  • About
  • Team Members
  • Testimonials
  • Contact

Copyright by @SprintChase  All Rights Reserved

  • PRIVACY
  • TERMS & CONDITIONS