广播机制#

广播机制是 PyTorch 用于处理张量间非显式算术运算的核心机制。

例如,显然无法将一个 \(3 \times 3\) 矩阵与 \(4 \times 2\) 矩阵相加,这将导致错误。然而,将一个标量加到 \(3 \times 3\) 矩阵上,或者将一个长度为 3 的向量加到 \(3 \times 3\) 矩阵上,虽然逻辑不总是直观,但却是可行的。

PyTorch 的广播机制基于一组简单的规则,在操作张量时必须了解这些规则。

广播规则#

两个张量要满足可广播的条件,必须遵循以下规则:

  • 每个张量至少有一个维度。

  • 从最后一个维度开始遍历,每个维度的大小必须满足:

    • 相等;或

    • 其中一个为 1;或

    • 其中一个不存在(即维度缺失)。

我们通过示例来进一步说明:

import torch

# Deux tenseurs de la même taille sont toujours broadcastables
x=torch.zeros(5,7,3)
y=torch.zeros(5,7,3)

# Les deux tenseurs suivants ne sont pas broadcastables car x n'a pas au moins une dimension
x=torch.zeros((0,))
y=torch.zeros(2,2)

# On aligne les dimensions visuellement pour voir si les tenseurs sont broadcastables
# En partant de la droite,
# 1. x et y ont la même taille et sont de taille 1
# 2. y est de taille 1
# 3. x et y ont la même taille
# 4. la dimension de y n'existe pas
# Les deux tenseurs sont donc broadcastables
x=torch.zeros(5,3,4,1)
y=torch.zeros(  3,1,1)

# A l'inverse, ces deux tenseurs ne sont pas broadcastables car 3. x et y n'ont pas la même taille
x=torch.zeros(5,2,4,1)
y=torch.zeros(  3,1,1)

现在我们知道如何判断两个张量是否可广播,接下来定义它们在运算时的具体规则:

广播规则如下:

  • 规则 1:若张量 xy 的维度数不同,则在维度较少的张量前添加 1,使两者维度对齐。

  • 规则 2:对于每个维度,结果的大小为 xy 对应维度大小的最大值。

广播时,维度较小的张量会沿对应维度复制,以匹配较大张量的形状。

注意

  • 如果两个张量不可广播,相加会直接报错。

  • 即使广播成功,结果也可能不是预期的。因此,理解这些规则至关重要。

我们重新审视之前的两个示例:

示例 1:将一个标量加到 \(3 \times 3\) 矩阵上:

x=torch.randn(3,3)
y=torch.tensor(1)
print("x : " ,x)
print("y : " ,y)
print("x+y : " ,x+y)
print("x+y shape : ",(x+y).shape)
# Le tenseur y est broadcasté pour avoir la même taille que x, il se transforme en tenseur de 1 de taille 3x3
x :  tensor([[ 0.6092, -0.6887,  0.3060],
        [ 1.3496,  1.7739, -0.4011],
        [-0.8876,  0.7196, -0.3810]])
y :  tensor(1)
x+y :  tensor([[1.6092, 0.3113, 1.3060],
        [2.3496, 2.7739, 0.5989],
        [0.1124, 1.7196, 0.6190]])
x+y shape :  torch.Size([3, 3])

示例 2:将一个长度为 3 的向量加到 \(3 \times 3\) 矩阵上:

x=torch.randn(3,3)
y=torch.tensor([1,2,3]) # tenseur de taille 3
print("x : " ,x)
print("y : " ,y)
print("x+y : " ,x+y)
print("x+y shape : ",(x+y).shape)
# Le tenseur y est broadcasté pour avoir la même taille que x, il se transforme en tenseur de 1 de taille 3x3
x :  tensor([[ 0.9929, -0.1435,  1.5740],
        [ 1.2143,  1.3366,  0.6415],
        [-0.2718,  0.3497, -0.2650]])
y :  tensor([1, 2, 3])
x+y :  tensor([[1.9929, 1.8565, 4.5740],
        [2.2143, 3.3366, 3.6415],
        [0.7282, 2.3497, 2.7350]])
x+y shape :  torch.Size([3, 3])

接下来,我们分析一些更复杂的示例:

x=torch.zeros(5,3,4,1)
y=torch.zeros(  3,1,1)
print("x+y shape : ",(x+y).shape)
# Le tenseur y a été étendu en taille 1x3x1x1 (règle 1) puis dupliqué en taille 5x3x4x1 (règle 2)
x+y shape :  torch.Size([5, 3, 4, 1])
x=torch.empty(1)
y=torch.empty(3,1,7)
print("x+y shape : ",(x+y).shape)
# Le tenseur y a été étendu en taille 1x1x1 (règle 1) puis dupliqué en taille 3x1x7 (règle 2)
x+y shape :  torch.Size([3, 1, 7])
x=torch.empty(5,2,4,1)
y=torch.empty(3,1,1)
print("x+y shape : ",(x+y).shape)
# L'opération n'est pas possible car les tenseurs ne sont pas broadcastables (dimension 3 en partant de la fin ne correspond pas)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[18], line 3
      1 x=torch.empty(5,2,4,1)
      2 y=torch.empty(3,1,1)
----> 3 print("x+y shape : ",(x+y).shape)
      4 # L'opération n'est pas possible car les tenseurs ne sont pas broadcastables (dimension 3 en partant de la fin ne correspond pas)

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

其他注意事项#

与标量的比较#

虽然容易被忽略,但广播机制能让我们方便地进行比较操作。

a = torch.tensor([10., 0, -4])
print(a > 0)
print(a==0)
tensor([ True, False, False])
tensor([False,  True, False])

我们还可以直接比较两个张量:

a=torch.tensor([1,2,3])
b=torch.tensor([4,2,6])
# Comparaison élément par élément
print(a==b)
# Comparaison élément par élément et égalité pour tous les éléments
print((a==b).all())
# Comparaison élément par élément et égalité pour au moins un élément
print((a==b).any())
# Comparaison avec supérieur ou égal
print(a>=b)
tensor([False,  True, False])
tensor(False)
tensor(True)
tensor([False,  True, False])

这在创建基于阈值的掩码,或验证两个操作是否等价时非常有用。

使用 unsqueeze()#

之前我们看到,可以将一个长度为 3 的张量广播到 \(3 \times 3\) 矩阵上。PyTorch 会自动将其转换为 \(1 \times 3\) 以执行操作。但有时我们需要反向操作,比如将一个 \(3 \times 1\) 的张量加到 \(3 \times 3\) 的矩阵上。

此时,需要手动应用规则 1,通过 unsqueeze() 函数为张量添加一个维度。

x=torch.tensor([1,2,3])
y=torch.randn(3,3)
print("y : ",y )
print("x+y : ",x+y) 

x=x.unsqueeze(1)
print("x shape : ",x.shape)
print("x+y : ",x+y)
y :  tensor([[ 1.3517,  1.1880,  0.4483],
        [ 0.5137, -0.5406, -0.1412],
        [-0.0108,  1.3757,  0.6112]])
x+y :  tensor([[2.3517, 3.1880, 3.4483],
        [1.5137, 1.4594, 2.8588],
        [0.9892, 3.3757, 3.6112]])
x shape :  torch.Size([3, 1])
x+y :  tensor([[2.3517, 2.1880, 1.4483],
        [2.5137, 1.4594, 1.8588],
        [2.9892, 4.3757, 3.6112]])

如您所见,我们通过这种方式绕过了 PyTorch 的默认广播规则,得到了所需的结果。

注意

  • PyTorch 的规则 1等价于不断调用 x.unsqueeze(0),直到两个张量的维度数相同。

  • 也可以用 None 替代 unsqueeze(),如下所示:

x=torch.tensor([1,2,3])
# La première opération est l'équivalent de unsqueeze(0) et la seconde de unsqueeze(1)
x[None].shape,x[...,None].shape
(torch.Size([1, 3]), torch.Size([3, 1]))

使用 keepdim#

PyTorch 中用于沿某一维度缩减张量的函数(如 torch.sum 求和、torch.mean 计算均值等)提供了一个实用参数 keepdim,在某些情况下非常有用。

这些操作会默认删除所操作的维度,从而改变张量的形状。

x=torch.randn(3,4,5)
print(x.shape)
x=x.sum(dim=1) # somme sur la dimension 1
print(x.shape)
torch.Size([3, 4, 5])
torch.Size([3, 5])

如果希望保留所操作的维度,可以设置参数 keepdim=True

x=torch.randn(3,4,5)
print(x.shape)
x=x.sum(dim=1,keepdim=True) # somme sur la dimension 1
print(x.shape)
torch.Size([3, 4, 5])
torch.Size([3, 1, 5])

这在避免维度错误时非常有用。我们来看一个会影响广播行为的示例:

x=torch.randn(3,4,5)
y=torch.randn(1,1,1)
x_sum=x.sum(dim=1)
x_sum_keepdim=x.sum(dim=1,keepdim=True)
print("Les deux opérations sont elles équivalentes ? :",(x_sum+y==x_sum_keepdim+y).all().item())
Les deux opérations sont elles équivalentes ? : False

具体过程如下:

  • 第一种情况x_sum 的形状为 \(3 \times 5\)。根据规则 1,它被转换为 \(1 \times 3 \times 5\),而 y 则通过规则 2 被广播为 \(1 \times 3 \times 5\)

  • 第二种情况x_sum_keepdim 的形状为 \(3 \times 1 \times 5\)y 仍通过规则 2 被广播为 \(1 \times 3 \times 5\)

爱因斯坦求和约定#

虽然这部分与广播机制没有直接关联,但了解它非常重要。

在 PyTorch 中,我们一直使用 @ 运算符(或 torch.matmul)进行矩阵乘法。另一种更灵活的方法是使用爱因斯坦求和约定torch.einsum),它是一种用于表示乘积和求和的简洁符号。

例如: ik,kj -> ij 左侧表示输入张量的维度,用逗号分隔。此处有两个张量,分别为 i,kk,j。右侧表示输出张量的维度,即 i,j

爱因斯坦求和约定的规则:

  • 左侧重复出现但右侧未出现的索引,表示对该索引求和。

  • 每个索引在左侧最多出现两次。

  • 左侧未重复的索引必须在右侧出现。

它可用于多种操作:

torch.einsum('ij->ji', a)  # 返回矩阵 a 的转置

torch.einsum('bi,ij,bj->b', a, b, c)

则返回一个长度为 b 的向量,其第 k 个元素为 \(a[k,i] \cdot b[i,j] \cdot c[k,j]\) 的和。这种表示法在处理多维批次数据时尤为方便。例如,对两个批次的矩阵进行逐批次矩阵乘法,可以使用:

torch.einsum('bik,bkj->bij', a, b)

这是 PyTorch 中进行矩阵乘法的实用方法。此外,它运行高效,通常是实现自定义操作的最优选择。