The torch.sub() function performs element-wise subtraction of two tensors, or subtracts a scalar from a tensor, with optional scaling of the value being subtracted. It supports PyTorch’s broadcasting to a common shape, type promotion, and integer, float, and complex inputs.

The above figure shows how the torch.sub() method works internally.
torch.subtract() is an alias for torch.sub().
Here is the formula for subtraction:
output = input - (alpha * other)
Each variable is explained in the Parameters section below.
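One quick way to check the formula is to compute input - alpha * other by hand and compare it with torch.sub(), the torch.subtract() alias, and the Tensor.sub() method form. The tensor values and alpha=2 below are arbitrary, picked only for illustration.

```python
import torch

# Arbitrary example values, chosen only to illustrate the formula
input_tensor = torch.tensor([10.0, 20.0, 30.0])
other = torch.tensor([1.0, 2.0, 3.0])
alpha = 2

# output = input - (alpha * other), computed by hand
manual = input_tensor - alpha * other

via_sub = torch.sub(input_tensor, other, alpha=alpha)
via_alias = torch.subtract(input_tensor, other, alpha=alpha)
via_method = input_tensor.sub(other, alpha=alpha)  # Tensor method form

print(manual)  # tensor([ 8., 16., 24.])
print(torch.equal(manual, via_sub),
      torch.equal(via_sub, via_alias),
      torch.equal(via_sub, via_method))  # True True True
```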
Syntax
torch.sub(input, other, *, alpha=1, out=None)
Parameters
| Argument | Description |
| --- | --- |
| input (Tensor) | The tensor to subtract from. |
| other (Tensor or Number) | The tensor or scalar value to subtract from input. |
| alpha (Number, optional, keyword-only) | A scalar multiplier applied to other before the subtraction. It defaults to 1, but you can set it to any value you need. |
| out (Tensor, optional, keyword-only) | A pre-allocated tensor in which to store the result. It defaults to None, in which case a new tensor is returned. |
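Because alpha and out come after the * in the signature, they are passed by name. The sketch below, with arbitrary values, writes the result into a pre-allocated tensor (the name result is just an illustrative choice) and checks that the returned tensor uses the same memory.

```python
import torch

a = torch.tensor([5.0, 7.0, 9.0])
b = torch.tensor([1.0, 2.0, 3.0])

# Pre-allocate an output tensor with the same shape and dtype as the result
result = torch.empty(3)

# alpha and out are keyword-only, so they are passed by name
returned = torch.sub(a, b, alpha=2, out=result)

print(result)  # tensor([3., 3., 3.])
# The returned tensor points at the same memory as the pre-allocated one
print(returned.data_ptr() == result.data_ptr())  # True
```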
Element-wise subtraction
Let’s define two tensors with the same shape and dtype and perform element-wise subtraction without scaling anything.
```python
import torch

first_tensor = torch.tensor([11.0, 21.0, 18.0])
other_tensor = torch.tensor([9.0, 19.0, 10.0])

subtract_tensor = torch.sub(first_tensor, other_tensor)
print(subtract_tensor)
# Output: tensor([2., 2., 8.])
```
In the above code, the first element of other_tensor is subtracted from the first element of first_tensor: 11.0 - 9.0 = 2.0. Therefore, the first element of subtract_tensor is 2.0.
The second elements work the same way (21.0 - 19.0 = 2.0), and so do the third (18.0 - 10.0 = 8.0).
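To make "element-wise" concrete, the sketch below pairs up elements at the same index in plain Python and compares the result with torch.sub(). The list comprehension is only an illustration of the idea, not how PyTorch computes it internally.

```python
import torch

first_tensor = torch.tensor([11.0, 21.0, 18.0])
other_tensor = torch.tensor([9.0, 19.0, 10.0])

# Subtract elements that share the same index, one pair at a time
by_hand = [x - y for x, y in zip(first_tensor.tolist(), other_tensor.tolist())]
print(by_hand)  # [2.0, 2.0, 8.0]

# The same computation as a single vectorized call
print(torch.sub(first_tensor, other_tensor).tolist())  # [2.0, 2.0, 8.0]
```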
Scalar subtraction

When other is a single value (a scalar), it is broadcast to every element of the tensor. This example uses the torch.subtract() alias.
```python
import torch

tensor = torch.tensor([12.0, 13.0, 14.0])
scalar = 4

# The scalar is subtracted from every element
scalar_subtract_tensor = torch.subtract(tensor, scalar)
print(scalar_subtract_tensor)
# Output: tensor([ 8.,  9., 10.])
```
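Scalar subtraction is also where type promotion shows up. In the sketch below (arbitrary values), an integer tensor minus an integer stays integer, minus a float becomes the default floating dtype (float32, assuming the default has not been changed), and a float tensor minus a complex scalar becomes complex.

```python
import torch

int_tensor = torch.tensor([12, 13, 14])  # dtype: torch.int64

int_result = torch.sub(int_tensor, 4)      # int - int stays integer
float_result = torch.sub(int_tensor, 0.5)  # int - float promotes to float
complex_result = torch.sub(torch.tensor([1.0, 2.0]), 1j)  # float - complex promotes to complex

print(int_result.tolist(), int_result.dtype)      # [8, 9, 10] torch.int64
print(float_result.tolist(), float_result.dtype)  # [11.5, 12.5, 13.5] torch.float32
print(complex_result.dtype)                       # torch.complex64
```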
Alpha scaling

```python
import torch

first_tensor = torch.tensor([10.0, 20.0, 30.0])
other_tensor = torch.tensor([1.0, 3.0, 5.0])

# Computes first_tensor - 3 * other_tensor
alpha_scaling_tensor = torch.subtract(first_tensor, other_tensor, alpha=3)
print(alpha_scaling_tensor)
# Output: tensor([ 7., 11., 15.])
```
In the above code, we calculated first_tensor - 3 * other_tensor, where other_tensor is scaled by alpha = 3 before the subtraction.
For the first element, 1.0 * 3 = 3.0 and 10.0 - 3.0 = 7.0, so the first element of the output tensor is 7.0.
The second element follows the same logic: 3.0 * 3 = 9.0 and 20.0 - 9.0 = 11.0.
For the third element: 5.0 * 3 = 15.0 and 30.0 - 15.0 = 15.0.
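Because alpha multiplies only other, torch.sub(x, y, alpha=k) should match x - k * y, and a negative alpha turns the subtraction into an addition. The sketch below checks both on the same tensors as above; alpha=-1 is just an illustrative choice.

```python
import torch

first_tensor = torch.tensor([10.0, 20.0, 30.0])
other_tensor = torch.tensor([1.0, 3.0, 5.0])

# alpha scales only other_tensor, never first_tensor
scaled = torch.sub(first_tensor, other_tensor, alpha=3)
print(torch.equal(scaled, first_tensor - 3 * other_tensor))  # True

# A negative alpha flips the sign: x - (-1 * y) == x + y
flipped = torch.sub(first_tensor, other_tensor, alpha=-1)
print(flipped)  # tensor([11., 23., 35.])
```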
Broadcasting
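Tensors with different but broadcast-compatible shapes can be subtracted directly. Comparing the shapes from the last dimension backward, each pair of sizes must be equal, or one of them must be 1 or missing.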
```python
import torch

a = torch.tensor([[5.0, 15.0], [4.0, 14.0]])  # shape (2, 2)
b = torch.tensor([3.0, 4.0])                   # shape (2,)

broadcasting_subtraction = torch.subtract(a, b)
print(broadcasting_subtraction)
# Output:
# tensor([[ 2., 11.],
#         [ 1., 10.]])
```
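In this code example, b has shape (2,) and a has shape (2, 2). b is broadcast to a’s shape by treating its single row as if it were repeated, giving [[3.0, 4.0], [3.0, 4.0]], without actually copying any data. That expanded tensor is then subtracted from a element-wise.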
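The sketch below makes the broadcast explicit with expand_as(), which returns a (2, 2) view of b without copying data. It then tries a column-shaped other (col, a name chosen for illustration, with shape (2, 1)) and a shape pair that cannot be broadcast, which raises a RuntimeError.

```python
import torch

a = torch.tensor([[5.0, 15.0], [4.0, 14.0]])
b = torch.tensor([3.0, 4.0])

# expand_as() gives a (2, 2) view of b, [[3., 4.], [3., 4.]], with no data copied
b_expanded = b.expand_as(a)
print(torch.equal(torch.sub(a, b), torch.sub(a, b_expanded)))  # True

# A (2, 1) column is repeated across the columns instead of the rows
col = torch.tensor([[1.0], [2.0]])
print(torch.broadcast_shapes(a.shape, col.shape))  # torch.Size([2, 2])
print(torch.sub(a, col))
# tensor([[ 4., 14.],
#         [ 2., 12.]])

# Trailing sizes 2 and 3 differ and neither is 1, so broadcasting fails
try:
    torch.sub(a, torch.tensor([1.0, 2.0, 3.0]))
except RuntimeError as err:
    print("RuntimeError:", err)
```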
That’s it!