基於numpy.einsum的張量網絡計算


張量與張量網絡

張量(Tensor)可以理解為廣義的矩陣,其主要特點在於將數字化的矩陣用圖形化的方式來表示,這就使得我們可以將一個大型的矩陣運算抽象化成一個具有良好性質的張量圖。由一個個張量所共同構成的運算網絡圖,就稱為張量網絡(Tensor Network)。讓我們用幾個常用的圖來看看張量網絡大概長什么樣子(下圖轉載自參考鏈接1):

上面這個圖從左到右分別表示:一階張量、二階張量以及三階張量,我們可以看出,一個張量的階數在圖像化的表示中被抽象稱為了張量的的數量,而中間的方形或者圓形則表示張量本身。實際上,一階張量代表的一個矢量,比如我們平時用python所定義的一個數組變量:

x = [1, 0]
y = [0, 1, 0]
z = [1, 2, 3, 4]

那么這里的x,y,z都是一階的張量。而二階張量所表示的含義是一個二維的矩陣,如我們常見的python多維數組:

M = [[1, -1], [-1, 1]]
N = [[1, 3], [2, 4], [5, 6]]

這里定義的M, N都是二階的張量。通過觀察這些示例中的一階和二階的張量我們可以得到一個規律:能夠用形如var[i]的形式讀取和遍歷var中的標量元素的就可以稱之為一階張量,能夠用形如var[i][j]的形式讀取和遍歷var中的標量元素的可以稱之為二階張量。顯然,屬於幾階的張量,跟張量內所包含的元素個數是無關的。那么根據這個客觀規律,我們可以再推廣到零階張量和更加高階的張量:

pi = 3.14
P = [[[1]]]
Q = [[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]

在上述的python變量定義中,pi就是一個零階的張量,零階張量實際上就等同於一個標量,而P, Q都是三階的張量。需要注意的是,雖然張量P只有一個元素,但是如果我們需要讀取這個標量元素,我們必須使用如下的python指令來執行:

print (P[0][0][0])

因此P也是一個有三條腿的張量。在使用張量的形式來表示單個矩陣的同時,我們需要考慮如果有多個矩陣的乘法運算,我們該如何表示?我們先以兩種形式的python矩陣運算來說明張量計算的表示方法:

import numpy as np
M = np.random.rand(2, 2)
v = np.random.rand(2)
w = np.dot(M, v)
print (M)
print (v)
print (w)

這一串python代碼表示的計算過程為:\(w_{2\times1}=M_{2\times2}\cdot v_{2\times1}\),為了不失廣泛有效性,這里使用隨機的張量來進行計算,這里的M表示二階張量,v,w表示一階張量。如果從矩陣運算的角度來理解,實際上就是一個\(2\times2\)的矩陣乘以一個\(2\times1\)的矢量,並且得到了一個新的\(2\times1\)的矢量。計算所得到的結果如下所示:

[[0.09660039 0.55849787]
 [0.93007524 0.32329559]]
[0.74966152 0.59573188]
[0.40513259 0.88983912]

同時我們也考慮下另外一種張量運算的場景,一個高階的張量與另外一個高階的張量進行運算:

import numpy as np
A = np.random.rand(1, 2, 2, 2)
B = np.random.rand(2, 2, 2)
C = np.einsum('ijkl,klm->ijm', A, B)
print ('A:', A)
print ('B:', B)
print ('C:', C)

這一串python代碼表示的計算過程為:\(C_{1\times2\times2}=A_{1\times2\times2\times2}\cdot B_{2\times2\times2}\),由於這里的多維張量運算已經不能使用普通的numpy.dot來處理,因此我們還是適用了專業的張量計算函數numpy.einsum來進行處理,計算結果如下:

A: [[[[0.85939221 0.43684494]
   [0.71895754 0.31222944]]

  [[0.49276976 0.13639093]
   [0.04176578 0.14400289]]]]
B: [[[0.21157005 0.58052674]
  [0.59166167 0.21243451]]

 [[0.11420572 0.98995349]
  [0.1145634  0.97101076]]]
C: [[[0.5581652  1.60661377]
  [0.20621996 0.49621469]]]

以上的兩個案例,從張量理論的角度來理解,相當於分別將張量w和張量C表示成了多個張量組合運算的結果。由多個張量構成的組合運算,我們可以使用張量網絡來表示:

上圖所示的\((a)\)\((b)\)就分別表示張量w和張量C的張量網絡圖。而這個將張量網絡的所有張量進行計算,最終得到一個或一系列的新的張量的矩陣乘加過程,我們也稱之為張量縮並,英文叫Tensor Contraction,注:上圖轉載自參考鏈接1

張量縮並順序與計算復雜性

不失廣泛有效性的,我們可以以兩個張量的縮並案例來分析張量縮並的復雜性,兩個張量縮並的計算復雜性主要取決於這兩個張量總的的數量,如果兩個張量之間有共用的,則計為1。以上圖中的\((a)\)為例,一個\(2\times2\)的矩陣乘以一個\(2\times1\)的矢量,一共需要4次乘法運算,而由Mv所構成的張量網絡一共有2條腿,那么4次的乘法預算符合\(O(d^2)\)的計算復雜性,這里的d指的是指定的的維度,常用的是2。相關的復雜性除了理論推導,用numpy.einsum的功能模塊也可以實現程序判斷:

import numpy as np
M = np.random.rand(2, 2)
v = np.random.rand(2)
path_info = np.einsum_path('ij,j->i', M, v, optimize='greedy')
print(path_info[0])
print(path_info[1])

輸出結果如下:

['einsum_path', (0, 1)]
  Complete contraction:  ij,j->i
         Naive scaling:  2
     Optimized scaling:  2
      Naive FLOP count:  8.000e+00
  Optimized FLOP count:  9.000e+00
   Theoretical speedup:  0.889
  Largest intermediate:  2.000e+00 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   2                     j,ij->i                                     i->i

這里的scaling就是上面提到的復雜性\(O(d^2)\)中的\(2\)。同樣的如果以上圖中的\((b)\)為例,我們可以通過理論推導出其計算復雜性為\(O(d^5)\),即理論的scaling應該是5,下面也通過程序實現來給出定論:

import numpy as np
A = np.random.rand(1, 2, 2, 2)
B = np.random.rand(2, 2, 2)
path_info = np.einsum_path('ijkl,klm->ijm', A, B, optimize='greedy')
print(path_info[0])
print(path_info[1])

以上程序的執行結果如下:

['einsum_path', (0, 1)]
  Complete contraction:  ijkl,klm->ijm
         Naive scaling:  5
     Optimized scaling:  5
      Naive FLOP count:  3.200e+01
  Optimized FLOP count:  3.300e+01
   Theoretical speedup:  0.970
  Largest intermediate:  4.000e+00 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   5               klm,ijkl->ijm                                 ijm->ijm

這里需要我們注意的一點是,如果有兩條邊同時連接,那么計算scaling的時候也是作為兩條邊來計算的,而不是合並為一條邊之后再計算scaling。

由於上面所提到的兩個例子,其實都只涉及到兩個張量之間的預算,當多個張量一同進行運算時,就會引入一個新的參量:縮並順序,在張量網絡的實際應用場景中,縮並順序會極大程度上的影響張量網絡計算的速度。首先,讓我們用一個例子來分析,為什么不同的縮並順序會對張量網絡計算的性能產生影響:給定四個張量為: \(a_{ijk},b_{jlmn},c_{klo}和d_{mo}\) 。如果先縮並 bc ,則對應的計算復雜度的scaling為6。若按照縮並順序:cd->c,bc->b,ab->a,對應的計算復雜度scaling為5 。也就是說,從復雜度的角度來說,這里選出了一條復雜度較低的縮並路線,這一條復雜性scaling較好的縮並順序也是由numpy.einsum的貪心算法找出來的:

import numpy as np
np.random.seed(123)
a = np.random.rand(2, 2, 2)
b = np.random.rand(2, 2, 2, 2)
c = np.random.rand(2, 2, 2)
d = np.random.rand(2, 2)
path_info = np.einsum_path('ijk,jlmn,klo,mo->in', a, b, c, d, optimize='greedy')
print(path_info[0])
print(path_info[1])

執行的結果如下所示:

['einsum_path', (2, 3), (1, 2), (0, 1)]
  Complete contraction:  ijk,jlmn,klo,mo->in
         Naive scaling:  7
     Optimized scaling:  5
      Naive FLOP count:  5.120e+02
  Optimized FLOP count:  1.290e+02
   Theoretical speedup:  3.969
  Largest intermediate:  8.000e+00 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   4                 mo,klo->klm                         ijk,jlmn,klm->in
   5               klm,jlmn->jkn                              ijk,jkn->in
   4                 jkn,ijk->in                                   in->in

張量分割對張量網絡縮並復雜性的影響

在前面的章節中我們討論了將一個張量網絡縮並為一個張量的場景下,如何降低其復雜性scaling。其中重點說明了,在特定的縮並順序下,可以極大程度上的優化張量縮並的性能。這里我們討論一種在量子計算中常用的技巧:張量的分割。簡單的來說,就是前面提到的張量縮並的逆向過程,既然可以將兩個張量縮並成一個,那就有可能將一個張量分割成兩個張量。

那么為什么需要執行張量分割的操作呢?我們可以直接通過一個案例來說明:

import numpy as np
np.random.seed(123)
a = np.random.rand(2)
b = np.random.rand(2)
c = np.random.rand(2, 2, 2, 2)
d = np.random.rand(2)
e = np.random.rand(2)
path_info = np.einsum_path('i,j,ijkl,k,l', a, b, c, d, e, optimize='greedy')
print(path_info[0])
print(path_info[1])

這里給定了5個張量,其中張量c有四條腿,那么在這個場景下不論以什么樣的順序進行縮並,得到的復雜性的scaling都必然是4,以下是numpy.einsum給出的結果:

['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
  Complete contraction:  i,j,ijkl,k,l->
         Naive scaling:  4
     Optimized scaling:  4
      Naive FLOP count:  8.000e+01
  Optimized FLOP count:  6.100e+01
   Theoretical speedup:  1.311
  Largest intermediate:  8.000e+00 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   4                 ijkl,i->jkl                              j,k,l,jkl->
   3                   jkl,j->kl                                 k,l,kl->
   2                     kl,k->l                                    l,l->
   1                       l,l->                                       ->

但是,如果我們考慮先將這個腿最多的張量做一個分割,使其變為兩個三條腿的張量,並且這兩個張量之間通過一條邊進行連接,代碼示例如下:

import numpy as np
np.random.seed(123)
a = np.random.rand(2)
b = np.random.rand(2)
c = np.random.rand(2, 2, 2)
d = np.random.rand(2, 2, 2)
e = np.random.rand(2)
f = np.random.rand(2)
path_info = np.einsum_path('i,j,imk,jml,k,l', a, b, c, d, e, f, optimize='greedy')
print(path_info[0])
print(path_info[1])

讓我們先看看numpy.einsum是否會給出一樣的縮並順序呢?

['einsum_path', (0, 2), (0, 1), (0, 2), (0, 1), (0, 1)]
  Complete contraction:  i,j,imk,jml,k,l->
         Naive scaling:  5
     Optimized scaling:  3
      Naive FLOP count:  1.920e+02
  Optimized FLOP count:  5.300e+01
   Theoretical speedup:  3.623
  Largest intermediate:  4.000e+00 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   3                   imk,i->km                           j,jml,k,l,km->
   3                   jml,j->lm                              k,l,km,lm->
   2                     km,k->m                                 l,lm,m->
   2                     lm,l->m                                    m,m->
   1                       m,m->                                       ->

我們驚訝的發現,這個給定的scaling較低的縮並順序並沒有一開始就縮並m這條邊,如果先縮並了m這條邊,那么得到的結果應該跟上面未作分割的順序和scaling是一樣的。另言之,我們通過這種張量切割的方案,實際上大大降低了這個張量網絡的縮並所需時間。這里的復雜性scaling每降低1,就意味着需要執行的乘加次數有可能減少到優化前的\(\frac{1}{d}\).

補充測試

針對於上述章節提到的張量分割的方案,我們這里再多一組更加復雜一些的張量網絡的測試:

import networkx as nx
graph = nx.Graph()
graph.add_nodes_from([1,2,3,4,5,6,7,8,9])
graph.add_edges_from([(1,4),(2,4),(3,5),(4,5),(4,6),(5,6),(6,7),(5,8),(6,9)])
nx.draw_networkx(graph)


考慮上圖這樣的一個張量網絡,我們也先將其中三個四條腿的張量進行分割,分割后的張量網絡變為如下所示的拓撲結構:

import networkx as nx
graph = nx.Graph()
graph.add_nodes_from([1,2,3,4,5,6,7,8,9,10,11,12])
graph.add_edges_from([(1,4),(5,4),(2,5),(4,5),(4,8),(5,6),(6,7),(3,7),(7,9),(6,11),(8,10),(8,9),(9,12)])
nx.draw_networkx(graph)


然后再次使用numpy.einsum來進行驗證。首先是張量分割前的張量網絡縮並:

import numpy as np
np.random.seed(123)
a = np.random.rand(2)
b = np.random.rand(2)
c = np.random.rand(2)
d = np.random.rand(2, 2, 2, 2)
e = np.random.rand(2, 2, 2, 2)
f = np.random.rand(2, 2, 2, 2)
g = np.random.rand(2)
h = np.random.rand(2)
i = np.random.rand(2)
path_info = np.einsum_path('i,j,k,ijlm,mnko,lpoq,p,n,q', a, b, c, d, e, f, g, h, i, optimize='greedy')
print(path_info[0])
print(path_info[1])

執行結果如下:

['einsum_path', (0, 3), (1, 2), (1, 2), (0, 3), (0, 2), (0, 1), (0, 1), (0, 1)]
  Complete contraction:  i,j,k,ijlm,mnko,lpoq,p,n,q->
         Naive scaling:  9
     Optimized scaling:  4
      Naive FLOP count:  4.608e+03
  Optimized FLOP count:  1.690e+02
   Theoretical speedup:  27.266
  Largest intermediate:  8.000e+00 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   4                 ijlm,i->jlm                j,k,mnko,lpoq,p,n,q,jlm->
   4                 mnko,k->mno                   j,lpoq,p,n,q,jlm,mno->
   4                 p,lpoq->loq                      j,n,q,jlm,mno,loq->
   3                   jlm,j->lm                         n,q,mno,loq,lm->
   3                   mno,n->mo                            q,loq,lm,mo->
   3                   loq,q->lo                               lm,mo,lo->
   3                   mo,lm->lo                                  lo,lo->
   2                     lo,lo->                                       ->

我們可以看到未進行張量分割前的復雜性scaling為4,再讓我們看下張量分割之后的張量網絡縮並情況:

import numpy as np
np.random.seed(123)
a = np.random.rand(2)
b = np.random.rand(2)
c = np.random.rand(2)
d = np.random.rand(2, 2, 2)
e = np.random.rand(2, 2, 2)
f = np.random.rand(2, 2, 2)
g = np.random.rand(2, 2, 2)
h = np.random.rand(2, 2, 2)
i = np.random.rand(2, 2, 2)
j = np.random.rand(2)
k = np.random.rand(2)
l = np.random.rand(2)
path_info = np.einsum_path('i,j,k,iml,jmn,nop,kpq,lrs,sqt,r,o,t', a, b, c, d, e, f, g, h, i, j, k, l, optimize='greedy')
print(path_info[0])
print(path_info[1])

執行結果如下:

['einsum_path', (0, 3), (0, 2), (0, 2), (0, 4), (0, 2), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
  Complete contraction:  i,j,k,iml,jmn,nop,kpq,lrs,sqt,r,o,t->
         Naive scaling:  12
     Optimized scaling:  3
      Naive FLOP count:  4.915e+04
  Optimized FLOP count:  1.690e+02
   Theoretical speedup:  290.840
  Largest intermediate:  4.000e+00 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   3                   iml,i->lm       j,k,jmn,nop,kpq,lrs,sqt,r,o,t,lm->
   3                   jmn,j->mn          k,nop,kpq,lrs,sqt,r,o,t,lm,mn->
   3                   kpq,k->pq             nop,lrs,sqt,r,o,t,lm,mn,pq->
   3                   o,nop->np                lrs,sqt,r,t,lm,mn,pq,np->
   3                   r,lrs->ls                   sqt,t,lm,mn,pq,np,ls->
   3                   t,sqt->qs                      lm,mn,pq,np,ls,qs->
   3                   mn,lm->ln                         pq,np,ls,qs,ln->
   3                   np,pq->nq                            ls,qs,ln,nq->
   3                   qs,ls->lq                               ln,nq,lq->
   3                   nq,ln->lq                                  lq,lq->
   2                     lq,lq->                                       ->

我們再次發現,張量縮並的復雜性scaling被優化到了3。假如是我們常見的\(d=2\)的張量網絡,那么在進行張量分割之后,類似於上面這個案例的,張量縮並的時間可以加速1倍甚至更多。

總結概要

本文主要介紹了張量網絡的基本定義及其縮並復雜性scaling的含義,其中利用numpy.einsum這個高級輪子進行了用例的演示,並且額外的介紹了張量分割在張量網絡縮並實際應用場景中的重要地位。通常我們會配合GPU來進行張量網絡的縮並,那么這個時候縮並復雜性的scaling影響的就不僅僅是縮並的速度,因為GPU本身的內存是比較局限的,而不斷的IO會進一步拉長張量網絡縮並所需要的時間。而如果能夠有方案將一個給定的張量網絡的復雜性scaling降低到GPU自身內存可以存儲的水平,那將極大程度上的降低使用張量網絡技術求解實際問題的時間。

版權聲明

本文首發鏈接為:https://www.cnblogs.com/dechinphy/p/tensor.html
作者ID:DechinPhy
更多原著文章請參考:https://www.cnblogs.com/dechinphy/

參考鏈接

  1. 什么是張量網絡(tensor network)? - 何史提的回答 - 知乎 https://www.zhihu.com/question/54786880/answer/147099121
  2. Michael Streif1, Martin Leib,"Training the Quantum Approximate Optimization Algorithm without access to a Quantum Processing Unit", 2019, arXiv:1908.08862


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM