Sprint Chase Technologies

torch.isnan: How to Check NaN Value in PyTorch

How to check for NaN values in PyTorch Tensor
  • Written by krunallathiya21
  • December 21, 2024
PyTorch

The torch.isnan() function identifies the elements of a tensor that are NaN (Not a Number). It accepts an input tensor and returns a boolean tensor that is True wherever the element is NaN and False otherwise.

Figure: torch.isnan() function in PyTorch

The above figure depicts how the method works! It evaluates each element of the tensor independently.

The result tensor has the same shape as the input tensor: True marks a NaN element and False marks everything else.

It does not modify the original tensor. NaN is a floating-point concept, so for an integer tensor every element of the result is False; no error is raised.

Syntax

torch.isnan(input)

Parameters

Argument    Description
input       The input tensor whose elements are checked for NaN.

Detecting NaN values

Basic usage of torch.isnan() function in PyTorch
import torch

# Define a tensor
tr = torch.tensor([1.0, float('nan'), 3.0, float('inf'), float('nan')])

# Print a tensor
print("Original tensor:", tr)

# Mask NaN values
isnan_mask = torch.isnan(tr)

# Print the mask
print("Detecting NaN values:", isnan_mask)

# Output
# Original tensor: tensor([1., nan, 3., inf, nan])
# Detecting NaN values: tensor([False,  True, False, False, True])

In the above code, the mask holds True at every position where "tr" is NaN and False everywhere else. Note that inf is not NaN, so it maps to False. The shape of "isnan_mask" is the same as "tr".
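
As a short sketch of why a dedicated function is needed: NaN never compares equal to anything, including itself, so an == comparison cannot find it. The example below also uses torch.isinf() and torch.isfinite(), which are not covered in this article, to show how infinite values are told apart from NaN. The variable names are illustrative only.

```python
import torch

tr = torch.tensor([1.0, float('nan'), 3.0, float('inf')])

# Comparing against NaN with == never matches, because NaN != NaN.
eq_mask = tr == float('nan')

# NaN is the only value that is not equal to itself.
self_ne_mask = tr != tr

# torch.isinf() and torch.isfinite() handle the infinite and finite cases.
inf_mask = torch.isinf(tr)
finite_mask = torch.isfinite(tr)

print(eq_mask)       # tensor([False, False, False, False])
print(self_ne_mask)  # tensor([False,  True, False, False])
print(inf_mask)      # tensor([False, False, False,  True])
print(finite_mask)   # tensor([ True, False,  True, False])
```

Here "self_ne_mask" matches what torch.isnan(tr) returns for this tensor.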

Filtering Out NaNs

You can filter out the NaN values from a tensor using a combination of torch.isnan(), the bitwise NOT operator (~), and boolean masking.

First, we determine which values are NaN using the isnan() function, which returns a tensor of boolean elements.

Then, we invert the boolean values using the bitwise NOT operator, so that True marks the values we want to keep.

Finally, we index the tensor with that inverted mask, which keeps only the non-NaN values.

import torch

tensor = torch.tensor([1.0, float('nan'), 3.0, float('nan')])

filtered_out_nan = tensor[~torch.isnan(tensor)]

print(filtered_out_nan)

# Output: tensor([1., 3.])

Counting the number of NaNs

If you want to count the number of NaN values in a tensor, you can use the combination of isnan() and sum() functions.

import torch

# Define a tensor
tr = torch.tensor([1.0, float('nan'), 3.0, float('inf'), float('nan')])

# Print a tensor
print("Original tensor:", tr)

# Counting NaN values
total_nans = torch.isnan(tr).sum()

# Print the count
print("Total number of NaNs:", total_nans)

# Output
# Original tensor: tensor([1., nan, 3., inf, nan])
# Total number of NaNs: tensor(2)

As expected, since there are two NaN values, the output is tensor(2). The count is returned as a zero-dimensional tensor, not as a Python integer.
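
If a plain Python number is more convenient than a zero-dimensional tensor, one possible approach is to call .item() on the result. The sketch below also derives the fraction of NaN entries by averaging the mask; "nan_count" and "nan_fraction" are names chosen here for illustration.

```python
import torch

tr = torch.tensor([1.0, float('nan'), 3.0, float('inf'), float('nan')])
nan_mask = torch.isnan(tr)

# .item() converts a zero-dimensional tensor into a Python scalar.
nan_count = nan_mask.sum().item()

# Averaging the mask (cast to float) gives the share of NaN elements.
nan_fraction = nan_mask.float().mean().item()

print("NaN count:", nan_count)                      # NaN count: 2
print("NaN fraction: {:.2f}".format(nan_fraction))  # NaN fraction: 0.40
```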

Using with torch.where() to replace NaN

To conditionally replace values, you can use the torch.where() function. It takes a boolean condition and two sources of values, and picks from the first where the condition is True and from the second where it is False. To replace NaN, we build the condition with torch.isnan() and supply the replacement scalar as the first source and the original tensor as the second.

import torch

tensor = torch.tensor([21.1, float('nan'), 1.9])

replaced_tensor = torch.where(torch.isnan(tensor), torch.tensor(0.0), tensor)

print(replaced_tensor)

# Output: tensor([21.1000,  0.0000,  1.9000])

The above program shows that we replaced the NaN value with 0.
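
As an alternative sketch, PyTorch also provides torch.nan_to_num(), which is not used elsewhere in this article and assumes a PyTorch version that includes it (1.8 or later). With nan=0.0 it produces the same result as the torch.where() call above. Be aware that by default it also replaces positive and negative infinity with the largest and smallest finite values of the dtype, which may or may not be what you want.

```python
import torch

tensor = torch.tensor([21.1, float('nan'), 1.9])

# Replace every NaN with 0.0 and return a new tensor.
replaced = torch.nan_to_num(tensor, nan=0.0)

print(replaced)  # tensor([21.1000,  0.0000,  1.9000])
```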

Replacing NaN values with a mean

You can replace NaN values with the mean of the remaining (non-NaN) values using the torch.isnan() and mean() functions.

import torch


def impute_with_mean(tensor):
    # Replace NaN values with the mean of the non-NaN values.
    # Note: this modifies the input tensor in place and also returns it.
    nan_mask = torch.isnan(tensor)
    if nan_mask.any():  # Only calculate mean if there are NaNs
        non_nan_values = tensor[~nan_mask]
        mean_val = non_nan_values.mean()
        tensor[nan_mask] = mean_val
    return tensor


# Example usage:
tr = torch.tensor([1.0, 2.0, float('nan'), 4.0, float('nan'), 6.0])
print("Original tensor:", tr)
imputed_tr = impute_with_mean(tr)
print("Imputed tensor:", imputed_tr)

# Output
# Original tensor: tensor([1., 2., nan, 4., nan, 6.])
# Imputed tensor: tensor([1.0000, 2.0000, 3.2500, 4.0000, 3.2500, 6.0000])

The mean of the non-NaN values (1, 2, 4, and 6) is 3.25, and in the imputed tensor, both NaN values are replaced with 3.2500. Note that the function writes into the tensor it receives, so "tr" itself is changed as well.
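
If the original tensor should stay untouched, a possible out-of-place variant is sketched below. It assumes a PyTorch version that provides torch.nanmean() (1.10 or later), which computes the mean while ignoring NaN values. The helper name "impute_with_mean_copy" is hypothetical and chosen only for this example.

```python
import torch


def impute_with_mean_copy(tensor):
    # Hypothetical helper: returns a new tensor and leaves the input unchanged.
    # torch.nanmean() ignores NaN entries when computing the mean.
    return torch.where(torch.isnan(tensor), torch.nanmean(tensor), tensor)


tr = torch.tensor([1.0, 2.0, float('nan'), 4.0, float('nan'), 6.0])
imputed = impute_with_mean_copy(tr)

print("Original tensor:", tr)
print("Imputed tensor:", imputed)

# Output
# Original tensor: tensor([1., 2., nan, 4., nan, 6.])
# Imputed tensor: tensor([1.0000, 2.0000, 3.2500, 4.0000, 3.2500, 6.0000])
```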

Checking for any NaN value

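If you want to check whether your input tensor contains any NaN value, you can use the torch.isnan(tensor).any().item() expression. Here any() reduces the boolean mask to a single value, and item() converts it to a Python bool.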

import torch

tensor = torch.tensor([21.1, float('nan'), 1.9])

has_nan = torch.isnan(tensor).any().item()

print(has_nan)

# Output: True

If the tensor contains no NaN value, it will return False.

import torch

tensor = torch.tensor([21.1, 19.1, 1.9])

has_nan = torch.isnan(tensor).any().item()

print(has_nan)

# Output: False
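
One way this check might be used in practice is as a guard that stops early when a tensor goes bad, for example a loss value during training. The sketch below is an illustration only; the helper name "check_no_nan" and the error message are assumptions, not part of PyTorch.

```python
import torch


def check_no_nan(name, tensor):
    # Hypothetical helper: raise an error if the tensor contains any NaN.
    if torch.isnan(tensor).any().item():
        raise ValueError(f"{name} contains NaN values")


# A clean tensor passes silently.
check_no_nan("clean", torch.tensor([21.1, 19.1, 1.9]))

# A tensor with NaN triggers the error.
try:
    check_no_nan("dirty", torch.tensor([21.1, float('nan'), 1.9]))
    raised = False
except ValueError as err:
    raised = True
    print(err)  # dirty contains NaN values
```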

Handling Multi-dimensional Tensors

The isnan() method can handle multidimensional tensors, such as 2D, 3D, or more. If the input is a 2D tensor, the output boolean tensor will be 2D, containing True values for NaN and False otherwise.

import torch

multi_tensor = torch.tensor([[1.0, float('nan')], [float('inf'), 4.0]])

print(torch.isnan(multi_tensor))

# Output:
# tensor([[False,  True],
#         [False, False]])
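
With a multidimensional mask, reductions along a dimension become useful. The following sketch, using an example tensor made up for illustration, counts NaNs per column and drops every row that contains at least one NaN. Indexing a 2D tensor directly with a 2D mask would flatten the result to 1D, so the mask is first reduced to one boolean per row.

```python
import torch

x = torch.tensor([[1.0, float('nan'), 3.0],
                  [4.0, 5.0, 6.0],
                  [float('nan'), float('nan'), 9.0]])

nan_mask = torch.isnan(x)

# Number of NaNs in each column (reduce over rows, dim=0).
nans_per_column = nan_mask.sum(dim=0)

# One boolean per row: True if the row contains any NaN (reduce over dim=1).
rows_with_nan = nan_mask.any(dim=1)

# Keep only the rows without NaN; the result is still 2D.
clean_rows = x[~rows_with_nan]

print(nans_per_column)  # tensor([1, 2, 0])
print(rows_with_nan)    # tensor([ True, False,  True])
print(clean_rows)       # tensor([[4., 5., 6.]])
```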

Handling Integer Tensors

NaN is a floating-point concept, and integer types have no way to represent it. So, if you are working with integer tensors, they can never contain NaN, and torch.isnan() returns False for every element.

import torch

int_tensor = torch.tensor([1, 2, 3], dtype=torch.int64)

is_nan_tensor = torch.isnan(int_tensor)

print(is_nan_tensor)

# Output: tensor([False, False, False])

Conclusion

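To check which elements of a tensor are NaN in PyTorch, use the torch.isnan() function. Its boolean mask can then be combined with indexing, sum(), any(), or torch.where() to filter, count, detect, or replace NaN values.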