广播机制#
广播机制是 PyTorch 用于处理张量间非显式算术运算的核心机制。
例如,显然无法将一个 \(3 \times 3\) 矩阵与 \(4 \times 2\) 矩阵相加,这将导致错误。然而,将一个标量加到 \(3 \times 3\) 矩阵上,或者将一个长度为 3 的向量加到 \(3 \times 3\) 矩阵上,虽然逻辑不总是直观,但却是可行的。
PyTorch 的广播机制基于一组简单的规则,在操作张量时必须了解这些规则。
广播规则#
两个张量要满足可广播的条件,必须遵循以下规则:
每个张量至少有一个维度。
从最后一个维度开始遍历,每个维度的大小必须满足:
相等;或
其中一个为 1;或
其中一个不存在(即维度缺失)。
我们通过示例来进一步说明:
import torch
# Deux tenseurs de la même taille sont toujours broadcastables
x=torch.zeros(5,7,3)
y=torch.zeros(5,7,3)
# Les deux tenseurs suivants ne sont pas broadcastables car x n'a pas au moins une dimension
x=torch.zeros((0,))
y=torch.zeros(2,2)
# On aligne les dimensions visuellement pour voir si les tenseurs sont broadcastables
# En partant de la droite,
# 1. x et y ont la même taille et sont de taille 1
# 2. y est de taille 1
# 3. x et y ont la même taille
# 4. la dimension de y n'existe pas
# Les deux tenseurs sont donc broadcastables
x=torch.zeros(5,3,4,1)
y=torch.zeros( 3,1,1)
# A l'inverse, ces deux tenseurs ne sont pas broadcastables car 3. x et y n'ont pas la même taille
x=torch.zeros(5,2,4,1)
y=torch.zeros( 3,1,1)
现在我们知道如何判断两个张量是否可广播,接下来定义它们在运算时的具体规则:
广播规则如下:
规则 1:若张量
x和y的维度数不同,则在维度较少的张量前添加1,使两者维度对齐。规则 2:对于每个维度,结果的大小为
x和y对应维度大小的最大值。
广播时,维度较小的张量会沿对应维度复制,以匹配较大张量的形状。
注意:
如果两个张量不可广播,相加会直接报错。
即使广播成功,结果也可能不是预期的。因此,理解这些规则至关重要。
我们重新审视之前的两个示例:
示例 1:将一个标量加到 \(3 \times 3\) 矩阵上:
x=torch.randn(3,3)
y=torch.tensor(1)
print("x : " ,x)
print("y : " ,y)
print("x+y : " ,x+y)
print("x+y shape : ",(x+y).shape)
# Le tenseur y est broadcasté pour avoir la même taille que x, il se transforme en tenseur de 1 de taille 3x3
x : tensor([[ 0.6092, -0.6887, 0.3060],
[ 1.3496, 1.7739, -0.4011],
[-0.8876, 0.7196, -0.3810]])
y : tensor(1)
x+y : tensor([[1.6092, 0.3113, 1.3060],
[2.3496, 2.7739, 0.5989],
[0.1124, 1.7196, 0.6190]])
x+y shape : torch.Size([3, 3])
示例 2:将一个长度为 3 的向量加到 \(3 \times 3\) 矩阵上:
x=torch.randn(3,3)
y=torch.tensor([1,2,3]) # tenseur de taille 3
print("x : " ,x)
print("y : " ,y)
print("x+y : " ,x+y)
print("x+y shape : ",(x+y).shape)
# Le tenseur y est broadcasté pour avoir la même taille que x, il se transforme en tenseur de 1 de taille 3x3
x : tensor([[ 0.9929, -0.1435, 1.5740],
[ 1.2143, 1.3366, 0.6415],
[-0.2718, 0.3497, -0.2650]])
y : tensor([1, 2, 3])
x+y : tensor([[1.9929, 1.8565, 4.5740],
[2.2143, 3.3366, 3.6415],
[0.7282, 2.3497, 2.7350]])
x+y shape : torch.Size([3, 3])
接下来,我们分析一些更复杂的示例:
x=torch.zeros(5,3,4,1)
y=torch.zeros( 3,1,1)
print("x+y shape : ",(x+y).shape)
# Le tenseur y a été étendu en taille 1x3x1x1 (règle 1) puis dupliqué en taille 5x3x4x1 (règle 2)
x+y shape : torch.Size([5, 3, 4, 1])
x=torch.empty(1)
y=torch.empty(3,1,7)
print("x+y shape : ",(x+y).shape)
# Le tenseur y a été étendu en taille 1x1x1 (règle 1) puis dupliqué en taille 3x1x7 (règle 2)
x+y shape : torch.Size([3, 1, 7])
x=torch.empty(5,2,4,1)
y=torch.empty(3,1,1)
print("x+y shape : ",(x+y).shape)
# L'opération n'est pas possible car les tenseurs ne sont pas broadcastables (dimension 3 en partant de la fin ne correspond pas)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[18], line 3
1 x=torch.empty(5,2,4,1)
2 y=torch.empty(3,1,1)
----> 3 print("x+y shape : ",(x+y).shape)
4 # L'opération n'est pas possible car les tenseurs ne sont pas broadcastables (dimension 3 en partant de la fin ne correspond pas)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
其他注意事项#
与标量的比较#
虽然容易被忽略,但广播机制能让我们方便地进行比较操作。
a = torch.tensor([10., 0, -4])
print(a > 0)
print(a==0)
tensor([ True, False, False])
tensor([False, True, False])
我们还可以直接比较两个张量:
a=torch.tensor([1,2,3])
b=torch.tensor([4,2,6])
# Comparaison élément par élément
print(a==b)
# Comparaison élément par élément et égalité pour tous les éléments
print((a==b).all())
# Comparaison élément par élément et égalité pour au moins un élément
print((a==b).any())
# Comparaison avec supérieur ou égal
print(a>=b)
tensor([False, True, False])
tensor(False)
tensor(True)
tensor([False, True, False])
这在创建基于阈值的掩码,或验证两个操作是否等价时非常有用。
使用 unsqueeze()#
之前我们看到,可以将一个长度为 3 的张量广播到 \(3 \times 3\) 矩阵上。PyTorch 会自动将其转换为 \(1 \times 3\) 以执行操作。但有时我们需要反向操作,比如将一个 \(3 \times 1\) 的张量加到 \(3 \times 3\) 的矩阵上。
此时,需要手动应用规则 1,通过 unsqueeze() 函数为张量添加一个维度。
x=torch.tensor([1,2,3])
y=torch.randn(3,3)
print("y : ",y )
print("x+y : ",x+y)
x=x.unsqueeze(1)
print("x shape : ",x.shape)
print("x+y : ",x+y)
y : tensor([[ 1.3517, 1.1880, 0.4483],
[ 0.5137, -0.5406, -0.1412],
[-0.0108, 1.3757, 0.6112]])
x+y : tensor([[2.3517, 3.1880, 3.4483],
[1.5137, 1.4594, 2.8588],
[0.9892, 3.3757, 3.6112]])
x shape : torch.Size([3, 1])
x+y : tensor([[2.3517, 2.1880, 1.4483],
[2.5137, 1.4594, 1.8588],
[2.9892, 4.3757, 3.6112]])
如您所见,我们通过这种方式绕过了 PyTorch 的默认广播规则,得到了所需的结果。
注意:
PyTorch 的规则 1等价于不断调用
x.unsqueeze(0),直到两个张量的维度数相同。也可以用
None替代unsqueeze(),如下所示:
x=torch.tensor([1,2,3])
# La première opération est l'équivalent de unsqueeze(0) et la seconde de unsqueeze(1)
x[None].shape,x[...,None].shape
(torch.Size([1, 3]), torch.Size([3, 1]))
使用 keepdim#
PyTorch 中用于沿某一维度缩减张量的函数(如 torch.sum 求和、torch.mean 计算均值等)提供了一个实用参数 keepdim,在某些情况下非常有用。
这些操作会默认删除所操作的维度,从而改变张量的形状。
x=torch.randn(3,4,5)
print(x.shape)
x=x.sum(dim=1) # somme sur la dimension 1
print(x.shape)
torch.Size([3, 4, 5])
torch.Size([3, 5])
如果希望保留所操作的维度,可以设置参数 keepdim=True。
x=torch.randn(3,4,5)
print(x.shape)
x=x.sum(dim=1,keepdim=True) # somme sur la dimension 1
print(x.shape)
torch.Size([3, 4, 5])
torch.Size([3, 1, 5])
这在避免维度错误时非常有用。我们来看一个会影响广播行为的示例:
x=torch.randn(3,4,5)
y=torch.randn(1,1,1)
x_sum=x.sum(dim=1)
x_sum_keepdim=x.sum(dim=1,keepdim=True)
print("Les deux opérations sont elles équivalentes ? :",(x_sum+y==x_sum_keepdim+y).all().item())
Les deux opérations sont elles équivalentes ? : False
具体过程如下:
第一种情况:
x_sum的形状为 \(3 \times 5\)。根据规则 1,它被转换为 \(1 \times 3 \times 5\),而y则通过规则 2 被广播为 \(1 \times 3 \times 5\)。第二种情况:
x_sum_keepdim的形状为 \(3 \times 1 \times 5\),y仍通过规则 2 被广播为 \(1 \times 3 \times 5\)。
爱因斯坦求和约定#
虽然这部分与广播机制没有直接关联,但了解它非常重要。
在 PyTorch 中,我们一直使用 @ 运算符(或 torch.matmul)进行矩阵乘法。另一种更灵活的方法是使用爱因斯坦求和约定(torch.einsum),它是一种用于表示乘积和求和的简洁符号。
例如:
ik,kj -> ij
左侧表示输入张量的维度,用逗号分隔。此处有两个张量,分别为 i,k 和 k,j。右侧表示输出张量的维度,即 i,j。
爱因斯坦求和约定的规则:
左侧重复出现但右侧未出现的索引,表示对该索引求和。
每个索引在左侧最多出现两次。
左侧未重复的索引必须在右侧出现。
它可用于多种操作:
torch.einsum('ij->ji', a) # 返回矩阵 a 的转置
而
torch.einsum('bi,ij,bj->b', a, b, c)
则返回一个长度为 b 的向量,其第 k 个元素为 \(a[k,i] \cdot b[i,j] \cdot c[k,j]\) 的和。这种表示法在处理多维批次数据时尤为方便。例如,对两个批次的矩阵进行逐批次矩阵乘法,可以使用:
torch.einsum('bik,bkj->bij', a, b)
这是 PyTorch 中进行矩阵乘法的实用方法。此外,它运行高效,通常是实现自定义操作的最优选择。