einsum:愛因斯坦求和約定


在Tensorflow、Numpy和PyTorch中都提供了使用einsum的api,einsum是一種能夠簡潔表示點積、外積、轉置、矩陣-向量乘法、矩陣-矩陣乘法等運算的領域特定語言。在Tensorflow等計算框架中使用einsum,操作矩陣運算時可以免於記憶和使用特定的函數,並且使得代碼簡潔,高效。

如對矩陣\(A\in \mathbb{R}^{I×K}​\)和矩陣\(B\in \mathbb{R}^{K×J}​\)做矩陣乘,然后對列求和,最終得到向量\(c\in \mathbb{R}^J​\),即:

\[\mathbb{R}^{I×K}\bigotimes \mathbb{R}^{K×J}\to \mathbb{R}^{I×J}\to \mathbb{R}^{J} \]

使用愛因斯坦求和約定表示為:

\[c_j=\sum_i\sum_kA_{ik}B_{kj}=A_{ik}B_{kj} \]

在Tensorflow、Numpy和PyTorch中對應的einsum字符串為:

ik,kj->j

在上面的字符串中,隱式地省略了重復的下標\(k\),表示在該維度矩陣乘;另外輸出中未指明下標\(i\),表示在該維度累加。

Numpy、PyTorch和Tensorflow中的einsum

einsum在Numpy中的實現為np.einsum,在PyTorch中的實現為torch.einsum,在Tensorflow中的實現為tf.einsum,均使用同樣的函數簽名einsum(equation,operands),其中,equation傳入愛因斯坦求和約定的字符串,而operands則是張量序列。在Numpy、Tensorflow中是變長參數列表,而在PyTorch中是列表。上述例子中,在Tensorflow中可寫作:

tf.einsum('ik,kj->j',mat1,mat2)

其中,mat1、mat2為執行該運算的兩個張量。注意:這里的(i,j,k)的命名是任意的,但在一個表達式中要一致。

PyTorch和Tensorflow像Numpy支持einsum的好處之一就是,einsum可以用於深度網絡架構的任意計算圖,並且可以反向傳播。在Numpy和Tensorflow中的調用格式如下:

\[result=\mathop{einsum}('\square \square, \square \square \square,\square \square\to \square \square',arg1,arg2,arg3) \]

其中,\(\square\)是占位符,表示張量維度;arg1,arg3是矩陣,arg2是三階張量,運算結果是矩陣。注意:einsum處理可變數量的輸入。上面例子中,einsum制定了三個參數的操作,但同樣可以操作一個參數、兩個參數和三個參數及以上的操作。

典型的einsum表達式

前置知識

  • 內積

    又稱點積、點乘,對應位置數字相乘,結果是一個標量,有見向量內積和矩陣內積等。

    向量\(\vec a\)和向量\(\vec b\)的內積:

    \[\vec a=[a_1,a_2,...,a_n]\\ \vec b=[b_1,b_2,...,b_n]\\ \vec a\cdot \vec b^T=a_1b_1+a_2b_2+...+a_nb_n \]

    內積幾何意義:

    \[\vec a \cdot \vec b^T=|\vec a||\vec b|\mathop{cos}\theta \]

  • 外積

    又稱叉乘、叉積、向量積,行向量矩陣乘列向量,結果是二階張量。注意到:張量的外積作為張量積的同義詞。外積是一種特殊的克羅內克積。

    向量\(\vec a\)和向量\(\vec b\)的外積:

    \[\begin{bmatrix} b_1 \\b_2 \\ b_3 \\ b_4 \end{bmatrix}\bigotimes[a_1,a_2,a_3]=\begin{bmatrix} a_1b_1 & a_2b_1 & a_3b_1 \\ a_1b_2 & a_2b_2 & a_3b_2 \\ a_1b_3 & a_2b_3 & a_3b_3 \\ a_1b_4 & a_2b_4 & a_3b_4 \\ \end{bmatrix} \]

    外積的幾何意義:

    \[\vec a=(x_1,y_1,z_1)\\ \vec b=(x_2,y_2,z_2)\\ \vec a\bigotimes\vec b=\begin{vmatrix} i & j & k\\ x_1 & y_1 & z_1\\ x_2 & y_2 & z_2 \end{vmatrix}=(y_1z_2-y_2z_1)\vec i-(x_1z_2-x_2z_1)\vec j+(x_1y_2-x_2y_1)\vec k \]

    其中,

    \[\vec i=(1,0,0)\\ \vec j=(0,1,0)\\ \vec k=(0,0,1) \]

由於PyTorch可以實時輸出運算結果,以PyTorch使用einsum表達式為例。

  • 矩陣轉置

    \[B_{ji}=A_{ij} \]

    a=torch.arange(6).reshape(2,3)
    
    >>>tensor([[0, 1, 2],
               [3, 4, 5]])
    
    torch.einsum('ij->ji',[a])
    
    >>>tensor([[0, 3],
              [1, 4],
              [2, 5]])
    
  • 求和

    \[b=\sum_{i}\sum_{j}A_{ij} \]

    a=torch.arange(6).reshape(2,3)
    
    >>>tensor([[0, 1, 2],
               [3, 4, 5]])
    
    torch.einsum('ij->',[a])
    
    >>>tensor(15)
    
  • 列求和(列維度不變,行維度消失)

    \[b_j=\sum_iA_{ij} \]

    a=torch.arange(6).reshape(2,3)
    
    >>>tensor([[0, 1, 2],
               [3, 4, 5]])
    
    torch.einsum('ij->j',[a])
    
    >>>tensor([ 3.,  5.,  7.])
    
  • 列求和(列維度不變,行維度消失)

    \[b_i=\sum_jA_{ij} \]

    a=torch.arange(6).reshape(2,3)
    
    >>>tensor([[0, 1, 2],
               [3, 4, 5]])
    
    torch.einsum('ij->i', [a])
    
    >>>tensor([  3.,  12.])
    
  • 矩陣-向量相乘

    \[c_i=\sum_k A_{ik}b_k \]

    a=torch.arange(6).reshape(2,3)
    
    >>>tensor([[0, 1, 2],
               [3, 4, 5]])
    
    torch.einsum('ik,k->i',[a,b])
    
    >>>tensor([  5.,  14.])
    
  • 矩陣-矩陣乘法

    \[C_{ij}=\sum_{k}A_{ik}B_{kj} \]

    a=torch.arange(6).reshape(2,3)
    b=torch.arange(15).reshape(3,5)
    
    >>>tensor([[0, 1, 2],
               [3, 4, 5]])
               
    >>>tensor([[ 0,  1,  2,  3,  4],
               [ 5,  6,  7,  8,  9],
               [10, 11, 12, 13, 14]])
    
    torch.einsum('ik,kj->ij',[a,b])
    
    >>>tensor([[ 25,  28,  31,  34,  37],
               [ 70,  82,  94, 106, 118]])
    
  • 點積

    • 向量

      \[c=\sum_i a_i b_i \]

      a=torch.arange(3)
      b=torch.arange(3,6)
      
      >>>tensor([0, 1, 2])
      >>>tensor([3, 4, 5])
      
      torch.einsum('i,i->',[a,b])
      
      >>>tensor(14.)
      
    • 矩陣

      \[c=\sum_i\sum_j A_{ij}B_{ij} \]

      a=torch.arange(6).reshape(2,3)
      b=torch.arange(6,12).reshape(2,3)
      
      >>>tensor([[0, 1, 2],
                 [3, 4, 5]])
      
      >>>tensor([[ 6,  7,  8],
                 [ 9, 10, 11]])
      
      torch.einsum('ij,ij->',[a,b])
      
      >>>tensor(145.)
      
  • 外積

    \[C_{ij}=a_i b_j \]

    a=torch.arange(3)
    b=torch.arange(3,7)
    
    >>>tensor([0, 1, 2])
    >>>tensor([3, 4, 5, 6])
    
    torch.einsum('i,j->ij',[a,b])
    
    >>>tensor([[  0.,   0.,   0.,   0.],
               [  3.,   4.,   5.,   6.],
               [  6.,   8.,  10.,  12.]])
    
  • batch矩陣乘

    \[C_{ijl}=\sum_{k}A_{ijk}B_{ikl} \]

    a=torch.randn(3,2,5)
    b=torch.randn(3,5,3)
    
    >>>tensor([[[-1.4131e+00,  3.8372e-02,  1.2436e+00,  5.4757e-01,  2.9478e-01],
                [ 1.3314e+00,  4.4003e-01,  2.3410e-01, -5.3948e-01, -9.9714e-01]],
    
               [[-4.6552e-01,  5.4318e-01,  2.1284e+00,  9.5029e-01, -8.2193e-01],
                [ 7.0617e-01,  9.8252e-01, -1.4406e+00,  1.0071e+00,  5.9477e-01]],
    
              [[-1.0482e+00,  4.7110e-02,  1.0014e+00, -6.0593e-01, -3.2076e-01],
               [ 6.6210e-01,  3.7603e-01,  1.0198e+00,  4.6591e-01, -7.0637e-04]]])
    
    >>>tensor([[[-2.1797e-01,  3.1329e-04,  4.3139e-01],
                [-1.0621e+00, -6.0904e-01, -4.6225e-01],
                [ 8.5050e-01, -5.8867e-01,  4.8824e-01],
                [ 2.8561e-01,  2.6806e-01,  2.0534e+00],
                [-5.5719e-01, -3.3391e-01,  8.4069e-03]],
    
               [[ 5.2877e-01,  1.4361e+00, -6.4232e-01],
                [ 1.0813e+00,  8.5241e-01, -1.1759e+00],
                [ 4.9389e-01, -1.7523e-01, -9.5224e-01],
                [-1.3484e+00, -5.4685e-01,  8.5539e-01],
                [ 3.7036e-01,  3.4368e-01, -4.9617e-01]],
    
               [[-2.1564e+00,  3.0861e-01,  3.4261e-01],
                [-2.3679e+00, -2.5035e-01,  1.8104e-02],
                [ 1.1075e+00,  7.2465e-01, -2.0981e-01],
                [-6.5387e-01, -1.3914e-01,  1.5205e+00],
                [-1.6561e+00, -3.5294e-01,  1.9589e+00]]])
    
    torch.einsum('ijk,ikl->ijl',[a,b])
    
    >>>tensor([[[ 1.3170, -0.7075,  1.1067],
                [-0.1569, -0.2170, -0.6309]],
    
               [[-0.1935, -1.3806, -1.1458],
                [-0.4135,  1.7577,  0.3293]],
    
               [[ 4.1854,  0.5879, -2.1180],
                [-1.4922,  0.7846,  0.7267]]])
    
  • 張量縮約

    batch矩陣相乘是張量縮約的一個特例,比如有兩個張量,一個n階張量\(A\in \mathbb{R}^{I_1×l_2×...×I_n}​\),一個m階張量\(B\in \mathbb{R}^{J_1×J_2×...×J_m}​\)。取n=4,m=5,假定維度\(I_2=J_3​\)\(I_3=J_5​\),將這兩個張量在這兩個維度上(A張量的第2、3維度,B張量的第3、5維度)相乘,獲得新張量\(C\in \mathbb{R}^{I_1×I_4×J_1×J_2×J_4}​\),如下所示:

    \[C_{I_1×I_4×J_1×J_2×J_4}=\sum_{I_2==J_3}\sum_{I_3==J_5}A_{I_1×I_2×I_3×I_4}B_{J_1×J_2×J_3×J_4×J_5} \]

    a=torch.randn(2,3,5,7)
    b=torch.randn(11,13,3,17,5)
    
    torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
    
    >>>torch.Size([2, 7, 11, 13, 17])
    
  • 多張量計算

    如前所述,einsum可用於超過兩個張量的計算,以雙線性變換為例:

    \[D_ij=\sum_k\sum_lA_{ik}B_{jkl}C_{il} \]

    a=torch.randn(2,3)
    b=torch.randn(5,3,7)
    c=torch.randn(2,7)
    
    torch.einsum('ik,jkl,il->ij',[a,b,c]).shape
    
    >>>torch.Size([2,5])
    

kimiyoung/transformer-xl的tf部分大量使用了einsum表達式。

einsum滿足你一切需要:深度學習中的愛因斯坦求和約定

向量點乘(內積)和叉乘(外積、向量積)概念及幾何意義解讀

矩陣外積與內積

外積-wiki


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM