Broadcasting#
Broadcasting is the mechanism PyTorch uses to perform arithmetic operations on tensors whose shapes differ.
For example, you cannot add a \(3 \times 3\) matrix to a \(4 \times 2\) matrix: this raises an error. However, adding a scalar to a \(3 \times 3\) matrix, or a vector of size \(3\) to a \(3 \times 3\) matrix, is possible, even if the logic is not always obvious.
PyTorch’s broadcasting relies on a few simple rules that determine when two tensors can be combined and what shape the result has.
Broadcasting Rules#
For two tensors to be broadcastable, they must follow these rules:
Each tensor must have at least one dimension.
When iterating over the dimension sizes (starting from the last), the sizes must be equal, or one of them must be 1, or one of them must be absent.
Let’s use examples to clarify:
import torch
# Two tensors of the same shape are always broadcastable
x=torch.zeros(5,7,3)
y=torch.zeros(5,7,3)
# These two tensors are not broadcastable: x is empty (its only dimension has size 0, which neither equals 2 nor is 1)
x=torch.zeros((0,))
y=torch.zeros(2,2)
# Align the shapes on the right to check whether the tensors are broadcastable
# Starting from the last dimension:
# 1. x and y both have size 1
# 2. y has size 1
# 3. x and y have the same size (3)
# 4. the dimension does not exist in y
# The two tensors are therefore broadcastable
x=torch.zeros(5,3,4,1)
y=torch.zeros(  3,1,1)
# Conversely, these two tensors are not broadcastable: at step 3, x has size 2 and y has size 3
x=torch.zeros(5,2,4,1)
y=torch.zeros(  3,1,1)
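The dimension-by-dimension check described above can be sketched in plain Python. The helper name `is_broadcastable` is hypothetical (it is not a PyTorch function), and the sketch only implements the right-to-left size comparison, treating a missing dimension like a size of 1. As a consequence, a 0-dimensional tensor such as `torch.tensor(1)` is considered compatible with anything, which matches what PyTorch does in practice.

```python
from itertools import zip_longest

def is_broadcastable(shape_x, shape_y):
    """Hypothetical helper: check two shapes against the broadcasting rule."""
    # Walk both shapes from the last dimension; a missing dimension counts as size 1.
    pairs = zip_longest(reversed(shape_x), reversed(shape_y), fillvalue=1)
    for size_x, size_y in pairs:
        # Sizes are compatible if they are equal or if one of them is 1.
        if size_x != size_y and size_x != 1 and size_y != 1:
            return False
    return True

# The four pairs of shapes used in the examples above.
print(is_broadcastable((5, 7, 3), (5, 7, 3)))
print(is_broadcastable((0,), (2, 2)))
print(is_broadcastable((5, 3, 4, 1), (3, 1, 1)))
print(is_broadcastable((5, 2, 4, 1), (3, 1, 1)))
```

The calls print True, False, True and False, in line with the comments in the examples.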
Now that we know how to identify two broadcastable tensors, let’s define the rules applied during the operation between them.
The rules are:
Rule 1: If x and y do not have the same number of dimensions, prepend dimensions of size 1 to the tensor with fewer dimensions until both have the same number.
Rule 2: For each dimension, the resulting size is the maximum of the sizes of x and y along that dimension.
A tensor whose size is 1 along a dimension is then duplicated along that dimension to match the resulting size.
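To make the two rules concrete, here is a minimal sketch that applies them by hand and compares the result with what PyTorch does implicitly. The tensor values are arbitrary, and the variable names (`y_aligned`, `y_expanded`) are chosen for this illustration only. It uses `unsqueeze(0)` for Rule 1 and `expand()` for Rule 2; `expand()` returns a view and does not copy the data, which shows up as a stride of 0 along the duplicated dimensions.

```python
import torch

x = torch.zeros(5, 3, 4, 1)
y = torch.arange(3.0).reshape(3, 1, 1)

# Rule 1: prepend a dimension of size 1 so that y has as many dimensions as x.
y_aligned = y.unsqueeze(0)  # shape (1, 3, 1, 1)

# Rule 2: the resulting size is the maximum of the two sizes along each dimension.
result_shape = tuple(max(a, b) for a, b in zip(x.shape, y_aligned.shape))

# Duplicate y along its size-1 dimensions (a view, no data is copied).
y_expanded = y_aligned.expand(result_shape)

print(y_aligned.shape)
print(result_shape)
print(y_expanded.stride())

# The manual version gives the same result as implicit broadcasting.
print(torch.equal(x + y, x + y_expanded))
```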
Note: If two tensors are not broadcastable, their addition will result in an error. However, in many cases, broadcasting will work but will not produce the desired result. This is why it is crucial to master these rules.
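As an illustration of broadcasting that works but does not give the desired result, consider a difference between a column of predictions and a vector of targets. The names `pred` and `target` and their values are assumptions made for this sketch. The shapes \(3 \times 1\) and \(3\) are broadcastable, so no error is raised, but the result is a \(3 \times 3\) matrix rather than 3 values.

```python
import torch

pred = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1), e.g. the output of a model
target = torch.tensor([1.0, 2.0, 3.0])      # shape (3,)

# Silent mistake: (3, 1) and (3,) broadcast to (3, 3), every pair (i, j) is compared.
diff = pred - target
print(diff.shape)
print((diff ** 2).mean())

# Removing the extra dimension first gives the intended element-wise difference.
diff_fixed = pred.squeeze(1) - target
print(diff_fixed.shape)
print((diff_fixed ** 2).mean())
```

Here the predictions equal the targets, so the intended mean squared difference is 0, yet the first version reports a nonzero value.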
Let’s revisit our two examples:
Adding a scalar to a \(3 \times 3\) matrix:
x=torch.randn(3,3)
y=torch.tensor(1)
print("x : " ,x)
print("y : " ,y)
print("x+y : " ,x+y)
print("x+y shape : ",(x+y).shape)
# y is broadcast to the shape of x: it behaves like a 3x3 tensor filled with 1
x : tensor([[ 0.6092, -0.6887, 0.3060],
[ 1.3496, 1.7739, -0.4011],
[-0.8876, 0.7196, -0.3810]])
y : tensor(1)
x+y : tensor([[1.6092, 0.3113, 1.3060],
[2.3496, 2.7739, 0.5989],
[0.1124, 1.7196, 0.6190]])
x+y shape : torch.Size([3, 3])
Adding a vector of size \(3\) to a \(3 \times 3\) matrix:
x=torch.randn(3,3)
y=torch.tensor([1,2,3]) # tensor of size 3
print("x : " ,x)
print("y : " ,y)
print("x+y : " ,x+y)
print("x+y shape : ",(x+y).shape)
# y is broadcast to the shape of x: it behaves like a 3x3 tensor whose rows are all [1, 2, 3]
x : tensor([[ 0.9929, -0.1435, 1.5740],
[ 1.2143, 1.3366, 0.6415],
[-0.2718, 0.3497, -0.2650]])
y : tensor([1, 2, 3])
x+y : tensor([[1.9929, 1.8565, 4.5740],
[2.2143, 3.3366, 3.6415],
[0.7282, 2.3497, 2.7350]])
x+y shape : torch.Size([3, 3])
Let’s now examine some more complex examples:
x=torch.zeros(5,3,4,1)
y=torch.zeros(  3,1,1)
print("x+y shape : ",(x+y).shape)
# y was extended to size 1x3x1x1 (Rule 1), then duplicated to size 5x3x4x1 (Rule 2)
x+y shape : torch.Size([5, 3, 4, 1])
x=torch.empty(1)
y=torch.empty(3,1,7)
print("x+y shape : ",(x+y).shape)
# x was extended to size 1x1x1 (Rule 1), then duplicated to size 3x1x7 (Rule 2)
x+y shape : torch.Size([3, 1, 7])
x=torch.empty(5,2,4,1)
y=torch.empty(3,1,1)
print("x+y shape : ",(x+y).shape)
# The operation is not possible because the tensors are not broadcastable (the third dimension from the end does not match: 2 vs 3)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[18], line 3
1 x=torch.empty(5,2,4,1)
2 y=torch.empty(3,1,1)
----> 3 print("x+y shape : ",(x+y).shape)
4 # The operation is not possible because the tensors are not broadcastable (the third dimension from the end does not match: 2 vs 3)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
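If you prefer to check shapes before running an operation, one option is `torch.broadcast_shapes`, which computes the broadcast shape without allocating any tensor and raises a RuntimeError when the shapes are incompatible. The wrapper name `broadcast_shape_or_none` below is hypothetical and only serves this sketch.

```python
import torch

def broadcast_shape_or_none(*shapes):
    """Hypothetical helper: return the broadcast shape, or None if incompatible."""
    try:
        return torch.broadcast_shapes(*shapes)
    except RuntimeError:
        # Raised by PyTorch when the shapes cannot be broadcast together.
        return None

ok = broadcast_shape_or_none((5, 3, 4, 1), (3, 1, 1))
bad = broadcast_shape_or_none((5, 2, 4, 1), (3, 1, 1))
print(ok)
print(bad)
```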
Other Points to Consider#
Comparison with Scalars#
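It’s not always obvious, but comparing a tensor with a scalar also relies on broadcasting: the scalar is broadcast to the shape of the tensor and the comparison is performed element by element.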
a = torch.tensor([10., 0, -4])
print(a > 0)
print(a==0)
tensor([ True, False, False])
tensor([False, True, False])
You can also compare two tensors with each other:
a=torch.tensor([1,2,3])
b=torch.tensor([4,2,6])
# Element-wise comparison
print(a==b)
# True only if all elements are equal
print((a==b).all())
# True if at least one pair of elements is equal
print((a==b).any())
# Element-wise greater-than-or-equal comparison
print(a>=b)
tensor([False, True, False])
tensor(False)
tensor(True)
tensor([False, True, False])
This can be very useful for creating masks from a threshold, for example, or verifying that two operations are equivalent.
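Here is a short sketch of both uses. The threshold value and the tensors are arbitrary choices for the example. For the equivalence check, `torch.allclose` is used instead of `==` because it tolerates small floating-point differences.

```python
import torch

a = torch.tensor([10.0, 0.0, -4.0, 2.5])
threshold = 1.0  # arbitrary value chosen for this example

# The scalar threshold is broadcast to the shape of a.
mask = a > threshold

# Keep only the values above the threshold.
selected = a[mask]

# Replace the values at or below the threshold with 0.
clipped = torch.where(mask, a, torch.zeros_like(a))

# Check that two ways of summing each row give the same result.
x = torch.arange(12.0).reshape(4, 3)
same = torch.allclose(x.sum(dim=1), x @ torch.ones(3))

print(mask)
print(selected)
print(clipped)
print(same)
```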
Using unsqueeze()#
We saw earlier that a tensor of size \(3\) can be broadcast against a matrix of size \(3 \times 3\): Rule 1 automatically treats it as a tensor of size \(1 \times 3\), so the same row is added to every row of the matrix. However, you might want to perform the operation in the other direction, i.e., add a \(3 \times 1\) tensor (a column) to a matrix of size \(3 \times 3\).
In this case, you need to add the dimension yourself instead of relying on Rule 1, using the unsqueeze() function, which inserts a dimension of size 1 at the given position.
x=torch.tensor([1,2,3])
y=torch.randn(3,3)
print("y : ",y )
print("x+y : ",x+y)
x=x.unsqueeze(1)
print("x shape : ",x.shape)
print("x+y : ",x+y)
y : tensor([[ 1.3517, 1.1880, 0.4483],
[ 0.5137, -0.5406, -0.1412],
[-0.0108, 1.3757, 0.6112]])
x+y : tensor([[2.3517, 3.1880, 3.4483],
[1.5137, 1.4594, 2.8588],
[0.9892, 3.3757, 3.6112]])
x shape : torch.Size([3, 1])
x+y : tensor([[2.3517, 2.1880, 1.4483],
[2.5137, 1.4594, 1.8588],
[2.9892, 4.3757, 3.6112]])
As you can see, choosing where the new dimension goes lets us control how broadcasting is applied and get the desired result.
Note:
PyTorch’s Rule 1 is equivalent to applying x.unsqueeze(0) until the number of dimensions is the same.
unsqueeze() can also be replaced by indexing with None, as follows:
x=torch.tensor([1,2,3])
# The first expression is equivalent to unsqueeze(0) and the second to unsqueeze(-1) (i.e. unsqueeze(1) for this 1-D tensor)
x[None].shape,x[...,None].shape
(torch.Size([1, 3]), torch.Size([3, 1]))
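A common use of this pattern is to combine every element of a vector with every other element. The sketch below, with arbitrary values, builds the matrix of pairwise differences by giving the same vector a column shape and a row shape before subtracting.

```python
import torch

x = torch.tensor([1.0, 2.0, 4.0])

# x[:, None] has shape (3, 1) and x[None, :] has shape (1, 3).
# Broadcasting produces one entry per pair (i, j) in a (3, 3) result.
pairwise_diff = x[:, None] - x[None, :]

print(pairwise_diff.shape)
print(pairwise_diff)

# Indexing with None and unsqueeze() give the same tensor.
print(torch.equal(x[:, None], x.unsqueeze(1)))
```

The entry at position (i, j) is x[i] - x[j], so the diagonal is zero.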
Using keepdim#
PyTorch functions that reduce a tensor along a dimension (torch.sum to sum along a dimension, torch.mean to compute the mean, etc.) accept a keepdim parameter that is useful in certain cases.
By default, these operations remove the dimension along which the reduction was performed, so the result has one dimension fewer.
x=torch.randn(3,4,5)
print(x.shape)
x=x.sum(dim=1) # sum over dimension 1
print(x.shape)
torch.Size([3, 4, 5])
torch.Size([3, 5])
If you want to keep the dimension along which the sum is performed, you can use the argument keepdim=True.
x=torch.randn(3,4,5)
print(x.shape)
x=x.sum(dim=1,keepdim=True) # sum over dimension 1, keeping that dimension with size 1
print(x.shape)
torch.Size([3, 4, 5])
torch.Size([3, 1, 5])
This can be very useful to avoid errors with dimensions. Let’s examine a case where this affects broadcasting.
x=torch.randn(3,4,5)
y=torch.randn(1,1,1)
x_sum=x.sum(dim=1)
x_sum_keepdim=x.sum(dim=1,keepdim=True)
print("Are the two operations equivalent? :",(x_sum+y==x_sum_keepdim+y).all().item())
Are the two operations equivalent? : False
Here’s what happened:
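In the first case, x_sum has a size of \(3 \times 5\). Rule 1 transforms it into \(1 \times 3 \times 5\), and Rule 2 transforms y into \(1 \times 3 \times 5\), so the result has size \(1 \times 3 \times 5\).
In the second case, x_sum_keepdim has a size of \(3 \times 1 \times 5\), and Rule 2 transforms y into \(3 \times 1 \times 5\), so the result has size \(3 \times 1 \times 5\). The comparison then broadcasts both results to \(3 \times 3 \times 5\), where the values no longer line up, hence False.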
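A typical situation where this matters is normalizing each row of a matrix by its sum. The sketch below uses a small hand-picked square matrix (an assumption for the example). Without keepdim, the division still runs because the shapes are broadcastable, but each column is divided by the sum of the wrong row.

```python
import torch

x = torch.tensor([[1.0, 1.0, 2.0],
                  [2.0, 2.0, 4.0],
                  [1.0, 3.0, 4.0]])

row_sums = x.sum(dim=1)                        # shape (3,)
row_sums_keepdim = x.sum(dim=1, keepdim=True)  # shape (3, 1)

# (3, 3) / (3,): column j is divided by the sum of row j, which is not intended.
wrong = x / row_sums

# (3, 3) / (3, 1): each row is divided by its own sum.
right = x / row_sums_keepdim

print(wrong.sum(dim=1))
print(right.sum(dim=1))
```

With keepdim=True, every row of the result sums to 1, which is not the case for the first version.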
Einstein Notation#
This section is not directly related to broadcasting, but it is important to know.
To multiply matrices in PyTorch, we have used the @ operator (or torch.matmul) so far. Another way to perform matrix multiplications is Einstein summation (torch.einsum).
This is a compact notation for expressing products and sums, for example: ik,kj -> ij. The left side represents the dimensions of the inputs, separated by commas. Here, we have two tensors, each with two dimensions (i,k and k,j). The right side represents the dimensions of the result, i.e., a tensor of dimensions i,j.
The rules of Einstein notation are:
Repeated indices on the left are implicitly summed if they do not appear on the right.
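In the classical convention, each index appears at most twice on the left (torch.einsum itself accepts more repetitions).
In the classical convention, non-repeated indices on the left must appear on the right; with torch.einsum, any index omitted from the right is simply summed over.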
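To see what the notation means, the sketch below writes out ik,kj -> ij with explicit loops and compares the result with torch.einsum and with the @ operator. The input values are arbitrary (integers stored as floats, so the comparison is exact), and the loops are only there for illustration since they are much slower than the built-in operations.

```python
import torch

a = torch.arange(6.0).reshape(2, 3)   # dimensions i=2, k=3
b = torch.arange(12.0).reshape(3, 4)  # dimensions k=3, j=4

# k is repeated on the left and absent on the right, so it is summed over.
out = torch.zeros(2, 4)
for i in range(2):
    for j in range(4):
        for k in range(3):
            out[i, j] += a[i, k] * b[k, j]

einsum_out = torch.einsum('ik,kj->ij', a, b)

print(torch.allclose(out, einsum_out))
print(torch.allclose(out, a @ b))
```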
You can use it for various operations:
torch.einsum('ij->ji', a)
returns the transpose of the matrix a.
Whereas
torch.einsum('bi,ij,bj->b', a, b, c)
returns a vector with one entry per value of the index b (note that b is also the name of the second tensor here): its k-th coordinate is the sum over i and j of \(a[k,i]⋅b[i,j]⋅c[k,j]\).
This notation is particularly practical when you handle batches with multiple dimensions. For example, if you have two batches of matrices and you want to compute the matrix product per batch, you can use:
torch.einsum('bik,bkj->bij', a, b)
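The three expressions above can be checked against more familiar PyTorch operations. This is a minimal sketch with arbitrary integer-valued inputs; the second tensor is named `m` here (instead of b) to avoid confusion with the batch index, and `p` and `q` are hypothetical names for the two batches of matrices.

```python
import torch

a = torch.arange(6.0).reshape(2, 3)   # dimensions b=2, i=3
m = torch.arange(12.0).reshape(3, 4)  # dimensions i=3, j=4
c = torch.arange(8.0).reshape(2, 4)   # dimensions b=2, j=4

# 'ij->ji' swaps the two dimensions, like a transpose.
transposed = torch.einsum('ij->ji', a)
print(torch.equal(transposed, a.T))

# 'bi,ij,bj->b': for each batch element, sum over i and j of a[b,i] * m[i,j] * c[b,j].
bilinear = torch.einsum('bi,ij,bj->b', a, m, c)
print(torch.allclose(bilinear, ((a @ m) * c).sum(dim=1)))

# 'bik,bkj->bij' is a matrix product per batch element, like torch.bmm.
p = torch.arange(24.0).reshape(2, 3, 4)
q = torch.arange(40.0).reshape(2, 4, 5)
batched = torch.einsum('bik,bkj->bij', p, q)
print(torch.allclose(batched, torch.bmm(p, q)))
```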
This is a practical method for performing matrix multiplications in PyTorch. It is also generally fast, and it is often a convenient way to express custom operations that would otherwise require several transpositions and reshapes.