The torch.sum() function returns the sum of all elements in a tensor, or the sums along one or more specified dimensions.

Syntax
torch.sum(input, dim=None, keepdim=False, *, dtype=None)
# OR
torch.sum(input, *, dtype=None)
Parameters
Argument | Description |
input (Tensor) | The input tensor to sum. |
dim (int or tuple of ints, optional) | The dimension(s) along which to sum. If None, all elements are summed. |
keepdim (bool, optional) | If True, the reduced dimensions are retained with size 1. Defaults to False. |
dtype (torch.dtype, optional) | The data type of the output tensor. The input is cast to this type before summing. |
Calculating the sum of all elements
If you call torch.sum() on a tensor without passing dim, it returns a scalar (0-dimensional) tensor holding the total of all elements.
import torch

tensor = torch.tensor([[10, 20], [40, 50]], dtype=torch.float)

sum_all = torch.sum(tensor)
print(sum_all)
# Output: tensor(120.)
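As a small sketch of how the scalar result can be used, the snippet below compares the function form with the equivalent method form and converts the 0-dimensional tensor to a plain Python number with .item(). The tensor values are illustrative and do not come from the example above.

```python
import torch

# Illustrative values, chosen only for this sketch
t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

total_fn = torch.sum(t)      # function form
total_method = t.sum()       # method form, performs the same reduction
total_py = total_fn.item()   # convert the 0-dim tensor to a Python float

print(total_fn.shape)  # torch.Size([])
print(total_py)        # 10.0
```

Calling .item() only works on tensors that hold a single element, which is why it fits the full reduction shown here.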
Sum along a specific dimension

For a 2D tensor, dim=0 sums over the rows (down each column) and returns one total per column. dim=1 sums over the columns (across each row) and returns one total per row.
import torch

tensor = torch.tensor([[10, 4], [9, 2]], dtype=torch.float)

# Sum over the rows (dim=0): one total per column
row_sum = torch.sum(tensor, dim=0)
print(row_sum)
# Output: tensor([19., 6.])

# Sum over the columns (dim=1): one total per row
col_sum = torch.sum(tensor, dim=1)
print(col_sum)
# Output: tensor([14., 11.])
Our input tensor contains two rows and two columns, so it is a 2×2 matrix. With dim=0, the elements in each column are added: 10. + 9. = 19. and 4. + 2. = 6. So, our row_sum tensor becomes [19., 6.].
With dim=1, the elements in each row are added: 10. + 4. = 14. and 9. + 2. = 11. So, our col_sum tensor becomes [14., 11.].
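Because a 2×2 matrix gives results of the same length for both dimensions, it can be hard to see which axis was collapsed. Here is a minimal sketch with a 2×3 tensor (hypothetical values), where the output shape shows which dimension was reduced.

```python
import torch

# A 2x3 tensor (hypothetical values) so the two results have different lengths
t = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

over_rows = torch.sum(t, dim=0)  # collapses dim 0: one value per column
over_cols = torch.sum(t, dim=1)  # collapses dim 1: one value per row

print(over_rows)        # tensor([5., 7., 9.])
print(over_rows.shape)  # torch.Size([3])
print(over_cols)        # tensor([ 6., 15.])
print(over_cols.shape)  # torch.Size([2])
```

The dimension you pass is the one that disappears from the output shape.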
Sum across multiple dimensions
If you pass dim as a tuple, the sum is performed across all of the listed dimensions at once.
import torch

tensor = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.float)

sum_multiple_dim = torch.sum(tensor, dim=(1, 2))

print("Sum across dimensions 1 and 2:")
# Output: Sum across dimensions 1 and 2:
print(sum_multiple_dim)
# Output: tensor([10., 26.])
Make sure that every value in dim is within the tensor’s rank, because an out-of-range dim raises an error. For a tensor of rank n, valid values run from -n to n-1.
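The following sketch shows what an out-of-range dim looks like in practice. It assumes a recent PyTorch release, where the error surfaces as an IndexError. The exact exception type and message may differ between versions, so treat the except clause as an assumption.

```python
import torch

t = torch.zeros(2, 2, 2)  # rank 3, so valid dims are -3..2

last = torch.sum(t, dim=-1)  # negative index: same as dim=2
print(last.shape)  # torch.Size([2, 2])

try:
    torch.sum(t, dim=3)  # out of range for a rank-3 tensor
    raised = False
except IndexError as err:  # assumed exception type, see note above
    raised = True
    print("Invalid dim:", err)
```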
Preserving reduced dimension
import torch

tensor = torch.tensor([[11, 21], [31, 41]], dtype=torch.float)
print(tensor.shape)
# Output: torch.Size([2, 2])

reduced_tensor = torch.sum(tensor, dim=1, keepdim=True)
print(reduced_tensor)
# Output: tensor([[32.],
#                 [72.]])
print(reduced_tensor.shape)
# Output: torch.Size([2, 1])
The above output shows that keepdim=True keeps the reduced dimension with size 1, so the output shape is [2, 1] instead of [2]. This is helpful when broadcasting the result against the original [2, 2] tensor.
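To illustrate the broadcasting benefit, here is a minimal sketch that divides each row by its own total so that every row sums to 1. The values are made up for the illustration.

```python
import torch

# Hypothetical values for the illustration
t = torch.tensor([[1.0, 3.0],
                  [2.0, 6.0]])

row_totals = torch.sum(t, dim=1, keepdim=True)  # shape [2, 1]
normalized = t / row_totals                     # [2, 2] / [2, 1] broadcasts per row

print(normalized)
# tensor([[0.2500, 0.7500],
#         [0.2500, 0.7500]])
```

Without keepdim=True the totals would have shape [2], and for this square tensor the division would silently line up with the columns instead of the rows.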
Specifying output data type
import torch

tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.int32)

floating_sum = torch.sum(tensor, dtype=torch.float64)
print(floating_sum)
# Output: tensor(10., dtype=torch.float64)
You can see that we override the input’s int32 data type with float64, so the sum is computed and returned as float64.
Choosing a dtype with lower precision or a narrower range than the input may cause precision loss or overflow.
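A short sketch of both effects follows. It first shows the default behaviour for integer inputs, which are summed as int64 when no dtype is given. It then shows precision loss when a float64 input is summed as float32. The value 16777217 (2**24 + 1) is chosen for illustration because float64 can represent it exactly and float32 cannot.

```python
import torch

ints = torch.tensor([1, 2, 3], dtype=torch.int32)
default_sum = torch.sum(ints)  # integer inputs are summed as int64 by default
print(default_sum.dtype)  # torch.int64

# 2**24 + 1 is exactly representable in float64 but not in float32
big = torch.tensor([16777216.0, 1.0], dtype=torch.float64)
high = torch.sum(big)                      # stays float64
low = torch.sum(big, dtype=torch.float32)  # cast down before summing

print(high.item())  # 16777217.0
print(low.item())   # 16777216.0
```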
Empty Tensors

import torch

empty_tensor = torch.tensor([])

empty_sum = torch.sum(empty_tensor)
print(empty_sum)
# Output: tensor(0.)

That’s all!