In machine learning, NaN values in a dataset can bias a model and lead to unreliable predictions. Most ML models cannot handle NaN values, so you need to either detect and replace them or remove the affected data from the dataset.
To check which elements of a tensor are NaN in PyTorch, use the torch.isnan() function.
torch.isnan()
The torch.isnan() function accepts an input tensor and returns a boolean tensor of the same shape, where True marks each element that is “NaN” and False marks every other element, including inf and -inf.
Syntax
torch.isnan(input)
Parameters
Name | Value |
input | The input tensor whose elements are checked for the “NaN” value. |
Basic usage
import torch

# Define a tensor
tr = torch.tensor([1.0, float('nan'), 3.0, float('inf'), float('nan')])

# Print a tensor
print("Original tensor:", tr)

# Mask NaN values
isnan_mask = torch.isnan(tr)

# Print the mask
print("Detecting NaN values:", isnan_mask)

# Output
# Original tensor: tensor([1., nan, 3., inf, nan])
# Detecting NaN values: tensor([False, True, False, False, True])
You can see in the above code that the mask holds True at each position where “tr” contains nan and False everywhere else, including the position of inf. The shape of “isnan_mask” is the same as that of “tr”.
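To make the shape-preserving behavior more concrete, here is a minimal sketch on a 2-D tensor. The tensor name "batch" and its values are made up for illustration. It also compares torch.isnan() with torch.isinf() and torch.isfinite(), two related PyTorch functions that are handy when infinities need to be separated from NaNs.

```python
import torch

# Hypothetical 2-D example: rows could be samples, columns could be features
batch = torch.tensor([[1.0, float('nan')],
                      [float('inf'), 4.0]])

nan_mask = torch.isnan(batch)        # True only where the value is NaN
inf_mask = torch.isinf(batch)        # True only where the value is +inf or -inf
finite_mask = torch.isfinite(batch)  # True where the value is neither NaN nor infinite

# The mask always has the same shape as the input tensor
print("Same shape:", nan_mask.shape == batch.shape)
print("NaN mask:", nan_mask)
print("Inf mask:", inf_mask)
print("Finite mask:", finite_mask)
```

Note that inf is not reported by torch.isnan(), so a dataset containing infinities needs a separate check.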
Counting the number of NaN values
If you want to count the number of NaN values in a tensor, combine torch.isnan() with sum(). Each True in the mask counts as 1, so the sum of the mask is the number of NaN values.
import torch

# Define a tensor
tr = torch.tensor([1.0, float('nan'), 3.0, float('inf'), float('nan')])

# Print a tensor
print("Original tensor:", tr)

# Counting NaN values
total_nans = torch.isnan(tr).sum()

# Print the count
print("Total number of NaNs:", total_nans)

# Output
# Original tensor: tensor([1., nan, 3., inf, nan])
# Total number of NaNs: tensor(2)
As expected, since there are two NaN values, the output is tensor(2).
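The count comes back as a zero-dimensional tensor. As a small sketch of two common variations, the example below converts the count to a plain Python integer with item() and counts NaNs per column by passing dim=0 to sum(). The tensor "data" is a made-up 2-D example, not part of the original dataset.

```python
import torch

# Hypothetical 2 x 3 tensor with three NaN values
data = torch.tensor([[1.0, float('nan'), 3.0],
                     [float('nan'), float('nan'), 6.0]])

nan_mask = torch.isnan(data)

# Total count as a plain Python int instead of a 0-dim tensor
total = nan_mask.sum().item()

# Count of NaN values in each column (reduce over the row dimension)
per_column = nan_mask.sum(dim=0)

# Quick yes/no check for whether any NaN is present
has_nan = nan_mask.any().item()

print("Total NaNs:", total)
print("NaNs per column:", per_column)
print("Contains NaN:", has_nan)
```

The per-column count is useful for spotting a feature that is mostly missing before deciding how to handle it.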
Handling NaN values with a mean of the dataset
You can replace NaN values with the mean of the non-NaN values in the dataset by combining the torch.isnan() and mean() functions. Note that the function below modifies the tensor in place.
import torch

def impute_with_mean(tensor):
    # Replace NaN values in a tensor with the mean of the non-NaN values
    nan_mask = torch.isnan(tensor)
    if nan_mask.any():  # Only calculate the mean if there are NaNs
        non_nan_values = tensor[~nan_mask]
        mean_val = non_nan_values.mean()
        tensor[nan_mask] = mean_val
    return tensor

# Example usage:
tr = torch.tensor([1.0, 2.0, float('nan'), 4.0, float('nan'), 6.0])
print("Original tensor:", tr)

imputed_tr = impute_with_mean(tr)
print("Imputed tensor:", imputed_tr)

# Output
# Original tensor: tensor([1., 2., nan, 4., nan, 6.])
# Imputed tensor: tensor([1.0000, 2.0000, 3.2500, 4.0000, 3.2500, 6.0000])
You can see that the mean of the non-NaN values in the original tensor (1, 2, 4, and 6) is 3.2500, and in the imputed tensor, both NaN values are replaced with 3.2500.
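The impute_with_mean() function above writes the mean back into the tensor it receives. If the original data should stay untouched, one possible alternative is sketched below. It assumes a PyTorch version that provides torch.nanmean() (a mean that skips NaN entries) and uses torch.where() to build a new tensor. The helper name impute_with_mean_copy is hypothetical. The last line shows torch.nan_to_num(), which fills NaNs with a constant instead of the mean.

```python
import torch

def impute_with_mean_copy(tensor):
    # Hypothetical helper: returns a new tensor and leaves the input unchanged
    mean_val = torch.nanmean(tensor)  # mean computed over the non-NaN entries
    # Take mean_val where the entry is NaN, otherwise keep the original value
    return torch.where(torch.isnan(tensor), mean_val, tensor)

tr = torch.tensor([1.0, 2.0, float('nan'), 4.0, float('nan'), 6.0])

imputed = impute_with_mean_copy(tr)
print("Imputed copy:", imputed)
print("NaNs left in original:", torch.isnan(tr).sum().item())

# Constant fill: replace every NaN with 0.0 instead of the mean
zero_filled = torch.nan_to_num(tr, nan=0.0)
print("Zero-filled copy:", zero_filled)
```

Whether mean imputation or a constant fill is appropriate depends on the dataset, so treat both as starting points rather than a recommendation.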