The torch.eq() method computes element-wise equality between a tensor and another tensor (or a scalar). It returns a boolean tensor in which True marks positions where the elements are equal and False marks positions where they differ.

You can use it to build boolean masks from existing tensors. It is also handy in classification tasks, for example when comparing predicted labels with ground-truth labels.
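To make the classification use case concrete, here is a minimal sketch that computes accuracy from a batch of predictions. The logits and labels are hypothetical values made up for illustration. The same boolean mask is then reused to find the misclassified samples.

```python
import torch

# Hypothetical model outputs (logits) for 4 samples and 3 classes
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.3, 0.2, 0.9],
                       [1.2, 0.4, 0.1]])
targets = torch.tensor([0, 1, 1, 0])  # hypothetical ground-truth labels

preds = logits.argmax(dim=1)        # predicted class per sample: tensor([0, 1, 2, 0])
correct = torch.eq(preds, targets)  # boolean mask: tensor([ True,  True, False,  True])
accuracy = correct.float().mean()   # fraction of True values

print(correct)
print(accuracy.item())  # 0.75

# The same mask can locate elements, e.g. the misclassified samples
misclassified = torch.nonzero(~correct).flatten()
print(misclassified)  # tensor([2])
```

Converting the mask to float before calling mean() is needed because mean() is not defined for boolean tensors.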
Syntax
torch.eq(input, other, *, out=None)
Parameters
Argument | Description |
input (Tensor) | The first tensor to compare. |
other (Tensor or Scalar) | The second tensor, or a scalar, to compare with input. A tensor must be broadcastable with input. |
out (Tensor, optional) | A tensor to write the result into (typically of dtype torch.bool, matching the result). |
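As a quick sketch of the out argument, the snippet below preallocates a boolean tensor and lets torch.eq fill it in. It also checks that the method form (Tensor.eq) and the == operator return the same mask. The tensors reuse the values from the next example.

```python
import torch

a = torch.tensor([11, 2, 31])
b = torch.tensor([11, 0, 31])

# Preallocate a boolean tensor and let torch.eq write into it via out=
result = torch.empty(3, dtype=torch.bool)
torch.eq(a, b, out=result)
print(result)  # tensor([ True, False,  True])

# The method form and the == operator give the same result
assert torch.equal(a.eq(b), result)
assert torch.equal(a == b, result)
```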
Basic tensor equality
Let’s define two tensors of the same size and compare them element by element.
import torch

first_tensor = torch.tensor([11, 2, 31])
tensor_to_compare = torch.tensor([11, 0, 31])

print(torch.eq(first_tensor, tensor_to_compare))
# Output: tensor([ True, False,  True])
Only the middle elements (2 and 0) differ, so that position is False, while the other two positions are True.
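Because torch.eq works element by element, it does not directly answer the question "are these two tensors identical?". The sketch below reduces the mask with .all() and, for comparison, calls torch.equal, which returns a single Python bool (and returns False when the shapes differ).

```python
import torch

first_tensor = torch.tensor([11, 2, 31])
tensor_to_compare = torch.tensor([11, 0, 31])

# torch.eq is element-wise: one boolean per position
mask = torch.eq(first_tensor, tensor_to_compare)

# To ask "are all elements equal?", reduce the mask with .all() ...
all_equal = mask.all().item()  # False
# ... or use torch.equal, which returns a Python bool directly
same = torch.equal(first_tensor, tensor_to_compare)  # False

print(mask, all_equal, same)
```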
Scalar comparison

If the second argument is a scalar, every element of the tensor is compared with that scalar. The result is True where an element equals the scalar and False elsewhere.
import torch

first_tensor = torch.tensor([11, 2, 31])
scalar_value = 31

print(torch.eq(first_tensor, scalar_value))
# Output: tensor([False, False,  True])
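A common follow-up to a scalar comparison is counting or locating matches. This sketch uses a slightly extended tensor (an extra 31 added purely for illustration) and reduces the mask with sum() and torch.nonzero().

```python
import torch

first_tensor = torch.tensor([11, 2, 31, 31])  # extra 31 added for illustration

mask = torch.eq(first_tensor, 31)          # tensor([False, False,  True,  True])
count = mask.sum().item()                  # number of elements equal to 31 -> 2
positions = torch.nonzero(mask).flatten()  # indices where the mask is True -> tensor([2, 3])
print(count, positions)
```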
Broadcasting
What if the first tensor’s shape is (3, 1) and the second tensor’s shape is (3,)? Before comparing, PyTorch broadcasts both tensors to a common shape.
First, a size-1 dimension is prepended to the second tensor, so its shape (3,) is treated as (1, 3).
Then each size-1 dimension is stretched to match the other tensor, so comparing the first tensor (3×1) with the second tensor (1×3) produces a (3×3) result.
import torch

a = torch.tensor([[1], [2], [3]])  # Shape: (3, 1)
b = torch.tensor([1, 2, 4])        # Shape: (3,)

print(torch.eq(a, b))
# Output:
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False, False]])
Row by row, each element of a is compared with every element of b:
- Row 0: 1 == [1, 2, 4] → [True, False, False]
- Row 1: 2 == [1, 2, 4] → [False, True, False]
- Row 2: 3 == [1, 2, 4] → [False, False, False]
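To check the broadcast shape before running a comparison, torch.broadcast_shapes applies the same rules torch.eq uses. The sketch below also expands both tensors explicitly with torch.broadcast_tensors and confirms that the result matches the implicit broadcast.

```python
import torch

a = torch.tensor([[1], [2], [3]])  # shape (3, 1)
b = torch.tensor([1, 2, 4])        # shape (3,)

# (3,) is treated as (1, 3), then both size-1 dims are stretched to 3
print(torch.broadcast_shapes(a.shape, b.shape))  # torch.Size([3, 3])

# Expanding explicitly gives the same result as the implicit broadcast
a_exp, b_exp = torch.broadcast_tensors(a, b)
assert torch.equal(torch.eq(a, b), torch.eq(a_exp, b_exp))
print(torch.eq(a, b).shape)  # torch.Size([3, 3])
```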
NaN values
What if both tensors contain NaN at the same position? The comparison returns False there, because under IEEE 754 floating-point rules NaN is not equal to anything, including itself.
import torch

x = torch.tensor([float('nan'), 21.0])
y = torch.tensor([float('nan'), 21.0])

print(torch.eq(x, y))
# Output: tensor([False,  True])
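If you want positions where both tensors hold NaN to count as equal, one option is to combine torch.eq with torch.isnan, as in the minimal sketch below (the third element pair is added for illustration). torch.isclose with equal_nan=True gives the same result here; setting rtol=0 and atol=0 keeps it an exact comparison instead of a tolerance-based one.

```python
import torch

x = torch.tensor([float('nan'), 21.0, 5.0])
y = torch.tensor([float('nan'), 21.0, 6.0])

# Treat positions where both sides are NaN as equal
nan_aware = torch.eq(x, y) | (torch.isnan(x) & torch.isnan(y))
print(nan_aware)  # tensor([ True,  True, False])

# torch.isclose with zero tolerances and equal_nan=True behaves the same way
print(torch.isclose(x, y, rtol=0, atol=0, equal_nan=True))  # tensor([ True,  True, False])
```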
Empty tensors
If both input tensors are empty, it returns an empty tensor as well.
import torch

empty_t1 = torch.tensor([])
empty_t2 = torch.tensor([])

output = torch.eq(empty_t1, empty_t2)
print(output)
# Output: tensor([], dtype=torch.bool)

That’s all!