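The torch.prod() function computes the product of all elements in a tensor, or of the elements along a specified dimension. It does not protect against overflow. Integer products can silently overflow, and floating-point products can overflow to inf, so choose a dtype wide enough for the values involved.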

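Syntax
```python
torch.prod(input, *, dtype=None)
torch.prod(input, dim, keepdim=False, *, dtype=None, out=None)
```
The first form multiplies every element together. The second form reduces only the dimension given by dim.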
Parameters

| Argument | Description |
| --- | --- |
| input (Tensor) | The input tensor. |
| dim (int, optional) | The dimension along which to compute the product. If omitted, all elements are multiplied together. |
| keepdim (bool, optional) | Whether the output keeps the reduced dimension as size 1. Only used together with dim. Default: False. |
| dtype (torch.dtype, optional) | The desired data type of the output tensor. If specified, the input is cast to this type before the product is computed. |
| out (Tensor, optional) | A pre-allocated tensor in which to store the result. |
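A quick sketch that exercises each parameter from the table on a small 2 × 2 tensor, assuming a recent PyTorch release. The names t, rows, rows_kept, total_f64, and buf are illustrative only, and out is shown with the dim form of the call.

```python
import torch

t = torch.tensor([[2, 3], [4, 5]])

# No dim: every element is multiplied together -> 0-dimensional tensor
total = torch.prod(t)

# With dim: only that dimension is reduced
rows = torch.prod(t, dim=1)                      # shape (2,)
rows_kept = torch.prod(t, dim=1, keepdim=True)   # shape (2, 1)

# dtype: the input is cast before multiplying
total_f64 = torch.prod(t, dtype=torch.float64)

# out: write the result into a pre-allocated tensor
buf = torch.empty(2, dtype=torch.int64)
torch.prod(t, dim=1, out=buf)

print(total)            # tensor(120)
print(rows)             # tensor([ 6, 20])
print(rows_kept.shape)  # torch.Size([2, 1])
print(total_f64)        # tensor(120., dtype=torch.float64)
print(buf)              # tensor([ 6, 20])
```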
Product of all elements in a Tensor
Let’s calculate the product of all elements in the input tensor, reducing it to a single scalar value.
```python
import torch

tensor = torch.tensor([1, 5, 4, 3])
product = torch.prod(tensor)
print(product)
# Output: tensor(60)
```
Without dim, torch.prod() multiplies every element (1 * 5 * 4 * 3 = 60) and returns a 0-dimensional (scalar) tensor.
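To confirm that the result really is a 0-dimensional tensor, and to get a plain Python number out of it, the sketch below inspects its shape and calls .item(). The standard library's math.prod is used only as an independent cross-check of the value.

```python
import math
import torch

values = [1, 5, 4, 3]
tensor = torch.tensor(values)
product = torch.prod(tensor)

# A full reduction returns a tensor with no dimensions
print(product.dim())    # 0
print(product.shape)    # torch.Size([])

# .item() converts the 0-dimensional tensor to a Python int
as_int = product.item()
print(as_int, math.prod(values))  # 60 60
```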
Product along a specific dimension
Let’s calculate the product along a specified dimension. That dimension is reduced away, while every other dimension is kept.
For a 2D tensor, there are two choices:
- dim=0: multiplies down each column (across the rows), producing one value per column.
- dim=1: multiplies along each row (across the columns), producing one value per row.
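The same rule carries over to tensors with more than two dimensions: whichever dimension is passed as dim disappears from the result's shape, and a negative dim counts from the end. A small sketch, using an arbitrary 2 × 3 × 4 tensor chosen only for illustration:

```python
import torch

# Values 1..24 arranged as a 2 x 3 x 4 tensor
x = torch.arange(1, 25).reshape(2, 3, 4)

# Reducing a dimension removes it from the shape
print(torch.prod(x, dim=0).shape)  # torch.Size([3, 4])
print(torch.prod(x, dim=1).shape)  # torch.Size([2, 4])
print(torch.prod(x, dim=2).shape)  # torch.Size([2, 3])

# dim=-1 refers to the last dimension, which is dim=2 here
print(torch.equal(torch.prod(x, dim=-1), torch.prod(x, dim=2)))  # True
```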
Along dim=0 (across rows)

```python
import torch

tensor = torch.tensor([[5, 6, 3], [1, 8, 7]])

# Product across dim=0 (across rows)
product_across_rows = torch.prod(tensor, dim=0)
print(product_across_rows)
# Output: tensor([ 5, 48, 21])  (5*1, 6*8, 3*7)
```

As the output shows, for dim=0 the elements are multiplied column-wise, giving one product per column.
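Because this tensor has only two rows, reducing along dim=0 is the same as multiplying the two rows together element by element. The sketch below checks that equivalence; it holds for this two-row example, and a tensor with more rows would need every row multiplied in.

```python
import torch

tensor = torch.tensor([[5, 6, 3], [1, 8, 7]])

# Reducing dim=0 collapses the rows into one row of column products
by_prod = torch.prod(tensor, dim=0)

# With exactly two rows, that is row 0 times row 1, elementwise
by_hand = tensor[0] * tensor[1]

print(by_prod)                        # tensor([ 5, 48, 21])
print(torch.equal(by_prod, by_hand))  # True
```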
Along dim=1 (across columns)

```python
import torch

tensor = torch.tensor([[5, 6, 3], [1, 8, 7]])

# Product along dim=1 (across columns)
product_across_columns = torch.prod(tensor, dim=1)
print(product_across_columns)
# Output: tensor([90, 56])  (5*6*3, 1*8*7)
```

As the output shows, for dim=1 the elements are multiplied row-wise, giving one product per row.
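To make "one product per row" concrete, the sketch below computes the same result two other ways: by reducing each row separately in a loop, and by transposing the tensor so that rows become columns and then reducing along dim=0. Both are for illustration only; the single torch.prod(tensor, dim=1) call is the idiomatic form.

```python
import torch

tensor = torch.tensor([[5, 6, 3], [1, 8, 7]])

row_products = torch.prod(tensor, dim=1)

# Iterating over a 2D tensor yields its rows; reduce each one fully
loop_products = torch.stack([torch.prod(row) for row in tensor])

# Transposing turns rows into columns, so dim=0 now reduces each original row
transposed_products = torch.prod(tensor.T, dim=0)

print(row_products)  # tensor([90, 56])
print(torch.equal(row_products, loop_products))        # True
print(torch.equal(row_products, transposed_products))  # True
```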
Preserve reduced dimensions (keepdim=True)
Let’s retain the reduced dimension as size 1 in the output.
```python
import torch

tensor = torch.tensor([[5, 6], [1, 8]])
product_reduce = torch.prod(tensor, dim=1, keepdim=True)
print(product_reduce)
# Output:
# tensor([[30],
#         [ 8]])
```
With keepdim=True, the result has shape (2, 1) rather than (2,). It keeps the same number of dimensions as the input, so it broadcasts row by row against the original (2, 2) tensor.
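To see why the kept dimension matters, the sketch below divides each element by the product of its row. The (2, 1) result from keepdim=True broadcasts across the columns as intended. The (2,) result without keepdim still broadcasts without error, but it lines up with the columns instead, which silently gives a different answer. Float values are used here (an assumption for this example) so the division is not integer-valued.

```python
import torch

tensor = torch.tensor([[5.0, 6.0], [1.0, 8.0]])

kept = torch.prod(tensor, dim=1, keepdim=True)  # shape (2, 1): [[30.], [8.]]
dropped = torch.prod(tensor, dim=1)             # shape (2,):   [30., 8.]

# (2, 2) / (2, 1): each row is divided by its own product
per_row = tensor / kept

# (2, 2) / (2,): the products are paired with columns, not rows
per_column = tensor / dropped

print(per_row)
print(per_column)
```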
Specifying output data type
```python
import torch

tensor = torch.tensor([2.5, 12.9], dtype=torch.float32)
int_product = torch.prod(tensor, dtype=torch.int64)
print(int_product)
# Output: tensor(24)
```
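The dtype argument casts the input to the requested type before the product is computed. Converting float32 to int64 truncates the fractional parts, so 2.5 becomes 2 and 12.9 becomes 12, and the result is 2 * 12 = 24. Multiplying first and truncating afterwards would have given 32 (from 2.5 * 12.9 = 32.25).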
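Since torch.prod() does not guard against overflow, dtype is also the main way to keep a large product representable. In the sketch below, float32 tops out near 3.4e38, so two factors of 1e20 overflow to inf, while requesting float64 keeps the result finite. For products that would overflow even float64, summing logarithms is a common workaround; that log-space approach is a general technique added here for illustration, not something torch.prod() does internally, and it only applies to positive values.

```python
import torch

x = torch.tensor([1e20, 1e20], dtype=torch.float32)

# 1e40 exceeds the float32 range, so the product overflows to inf
narrow = torch.prod(x)
print(narrow)  # tensor(inf)

# Casting to float64 before multiplying keeps the product finite
wide = torch.prod(x, dtype=torch.float64)
print(wide)

# Log-space alternative for positive values: log(prod(x)) == sum(log(x))
log_product = torch.log(x.double()).sum()
print(log_product)  # roughly 92.10, which is 40 * ln(10)
```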
Handling zero and negative values
Let’s look at how edge cases, such as zeros and negative numbers, affect the result.
```python
import torch

tensor = torch.tensor([2, 0, -3])
zero_negative_product = torch.prod(tensor)
print(zero_negative_product)
# Output: tensor(0)
```
Because one element is 0, the whole product is 0: anything multiplied by 0 is 0. Negative values follow the usual sign rules, so an odd number of negative factors gives a negative product and an even number gives a positive one.
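A short sketch that checks those rules against torch.prod() on a few sample tensors. The helper expected_sign is hypothetical, written only for this illustration: it predicts 0 if any factor is zero, and otherwise predicts a negative sign exactly when the count of negative factors is odd.

```python
import torch


def expected_sign(t: torch.Tensor) -> int:
    """Hypothetical helper: predict the sign of torch.prod(t) from its factors."""
    if (t == 0).any():
        return 0
    negatives = int((t < 0).sum())
    return -1 if negatives % 2 else 1


samples = [
    torch.tensor([2, 0, -3]),   # contains a zero -> 0
    torch.tensor([2, -1, -3]),  # two negatives  -> positive
    torch.tensor([2, -1, 3]),   # one negative   -> negative
]

for t in samples:
    product = torch.prod(t)
    # torch.sign gives -1, 0 or 1, which should match the prediction
    print(t.tolist(), product.item(), int(torch.sign(product)), expected_sign(t))
```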