The torch.eq() function compares two tensors element-wise and returns a boolean tensor in which True marks positions where the elements are equal and False marks positions where they differ.
The output has dtype torch.bool and the same shape as the input when both shapes match; otherwise, it has the shape the two inputs broadcast to.
You can use it to build a boolean mask from existing tensors, for example, to compare predicted labels against true labels in a classification task.
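As a minimal sketch of the mask and classification use cases mentioned above, the snippet below compares predicted labels with true labels. The tensors `predictions` and `labels` are made-up illustration data, not taken from any real model.

```python
import torch

# Hypothetical predictions and ground-truth labels for five samples
predictions = torch.tensor([2, 0, 1, 1, 2])
labels = torch.tensor([2, 1, 1, 0, 2])

# Boolean mask: True where the prediction matches the label
correct_mask = torch.eq(predictions, labels)

# The mean of the mask (cast to float) is the accuracy
accuracy = correct_mask.float().mean().item()

# The mask can also index a tensor to keep only the matching entries
correct_labels = labels[correct_mask]

print(correct_mask)    # tensor([ True, False,  True, False,  True])
print(accuracy)        # roughly 0.6 (float32 precision)
print(correct_labels)  # tensor([2, 1, 2])
```

Casting the mask to float before averaging is one common way to turn matches into an accuracy score; summing the mask and dividing by its length would work just as well.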
Syntax
torch.eq(input, other, *, out=None)
Parameters
| Argument | Description |
| --- | --- |
| input (Tensor) | The first tensor in the comparison. |
| other (Tensor or Scalar) | The tensor or scalar to compare with input. A tensor must be broadcastable with input. |
| out (Tensor, optional) | The output tensor (must have dtype=torch.bool). It is a keyword-only argument. |
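The out argument can be illustrated with a short sketch. It assumes a preallocated boolean buffer whose shape already matches the result; the name `result_buffer` is just for illustration.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 0, 3])

# Preallocated boolean buffer with the same shape as the expected result
result_buffer = torch.empty(3, dtype=torch.bool)

# The comparison result is written into result_buffer
torch.eq(a, b, out=result_buffer)

print(result_buffer)  # tensor([ True, False,  True])
```

Reusing a buffer like this mainly helps when the same comparison runs repeatedly and you want to avoid allocating a new tensor each time.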
Basic tensor equality
Let’s define two tensors of equal size and check their equality element by element.
import torch
first_tensor = torch.tensor([11, 2, 31])
tensor_to_compare = torch.tensor([11, 0, 31])
print(torch.eq(first_tensor, tensor_to_compare))
# Output: tensor([ True, False,  True])
Only the middle pair (2 and 0) differs, so that position is False; the other two pairs match, so they are True.
Scalar comparison
If the first argument is a tensor and the second is a scalar, each element of the tensor is compared with the scalar: the result is True where they are equal and False elsewhere.
import torch
first_tensor = torch.tensor([11, 2, 31])
scalar_value = 31
print(torch.eq(first_tensor, scalar_value))
# Output: tensor([False, False,  True])
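For scalar comparisons, the function form, the tensor method, and the `==` operator give the same result. The sketch below shows the three spellings side by side and counts the matches; the input values are arbitrary illustration data.

```python
import torch

first_tensor = torch.tensor([11, 2, 31, 31])

# Three equivalent spellings of the same element-wise comparison
via_function = torch.eq(first_tensor, 31)
via_method = first_tensor.eq(31)
via_operator = first_tensor == 31

# Summing a boolean mask counts its True entries
match_count = via_function.sum().item()

print(via_operator)  # tensor([False, False,  True,  True])
print(match_count)   # 2
```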
Broadcasting
What if the first tensor’s shape is (3, 1) and the second tensor’s shape is (3,)? Before comparing, PyTorch broadcasts both tensors to a common shape.
The second tensor is first treated as having shape (1, 3).
Both tensors are then expanded to (3, 3), so comparing the first tensor (3×1) with the second tensor (1×3) results in a (3×3) boolean matrix.
import torch
a = torch.tensor([[1], [2], [3]])  # Shape: (3, 1)
b = torch.tensor([1, 2, 4])  # Shape: (3,)
print(torch.eq(a, b))
# Output:
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False, False]])
Here is the row-wise explanation:
- Row 0: 1 == [1, 2, 4] → [True, False, False]
- Row 1: 2 == [1, 2, 4] → [False, True, False]
- Row 2: 3 == [1, 2, 4] → [False, False, False]
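To check the broadcast shape before running a comparison, one option is `torch.broadcast_shapes`, sketched below. This assumes a PyTorch version that provides that helper (1.8 or later); the tensor `c` is an extra illustration value used to show what happens when shapes are incompatible.

```python
import torch

a = torch.tensor([[1], [2], [3]])  # Shape: (3, 1)
b = torch.tensor([1, 2, 4])        # Shape: (3,)

# Shape that torch.eq(a, b) will produce, computed without comparing anything
result_shape = torch.broadcast_shapes(a.shape, b.shape)
print(result_shape)  # torch.Size([3, 3])

# Shapes that cannot be broadcast together raise a RuntimeError
c = torch.tensor([1, 2])           # Shape: (2,)
shapes_compatible = True
try:
    torch.eq(b, c)
except RuntimeError as error:
    shapes_compatible = False
    print("Incompatible shapes:", error)
```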
NaN Values
What if both tensors contain NaN at the same position? The result there is False, because NaN is never equal to anything, including another NaN.
import torch
x = torch.tensor([float('nan'), 21.0])
y = torch.tensor([float('nan'), 21.0])
print(torch.eq(x, y))
# tensor([False, True])
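If you want matching NaN positions to count as equal, one possible workaround is to combine torch.eq() with torch.isnan(). This is a sketch of that idea with made-up values, not a behavior of torch.eq() itself.

```python
import torch

x = torch.tensor([float('nan'), 21.0, 5.0])
y = torch.tensor([float('nan'), 21.0, 7.0])

# True only where both tensors hold NaN at the same position
both_nan = torch.isnan(x) & torch.isnan(y)

# Equal values, or NaN on both sides, count as a match
nan_aware_equal = torch.eq(x, y) | both_nan

print(nan_aware_equal)  # tensor([ True,  True, False])
```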
Empty tensors
If both input tensors are empty, the result is an empty boolean tensor as well.
import torch
empty_t1 = torch.tensor([])
empty_t2 = torch.tensor([])
output = torch.eq(empty_t1, empty_t2)
print(output)
# Output: tensor([], dtype=torch.bool)
That’s all!
