einsum 入門
愛因斯坦求和(einsum),計算指定維度元素的乘積,再求和。可以將愛因斯坦求和記法理解為懶人求和記法,省略不寫 \(\Sigma\)。einsum 具有很強大的表達能力,使用 einsum 的表達式,可以表達各種各樣的操作,比如矩陣乘法、矩陣轉置。很多庫都支持 einsum,有 numpy,pytorch,tensorflow。
這篇文章分為三個部分。
- 第一部分,從代碼入手,直接上手 einsum,畢竟 einsum 是很直觀的,聰明的你看一看代碼就能理解。
- 第二部分,介紹愛因斯坦求和(einsum),介紹表達式各個部分的含義。
- 第三部分,給出幾道練習題,有興趣的讀者可以思考一下如何用 einsum 表達。
einsum 快速上手
下面以 pytorch 為例子,介紹 einsum 的一些常見的用法。
例1:矩陣乘法
矩陣乘法中,每個元素由下面的表達式確定。
第一個學習的 einsum 表達式是,ik,kj->ij
。前面提到過,愛因斯坦求和記法可以理解為懶人求和記法。將上述公式中的 \(\Sigma\) 去掉,並且將左右兩邊對調一下,省去矩陣之后,剩下的就是 ik,kj->ij
了。
例2:對角線元素
雖然 einsum 的名字當中有 sum,但是 einsum 可以不做求和。舉個例子,獲取二維方陣的對角線元素,結果放入一維向量。
上面,A 是一維向量,B 是二維方陣。使用 einsum 記法,可以寫作 ii->i
例3:跡(trace)
觀察一下,矩陣乘法和對角線元素兩個表達式有什么區別。ik,kj->ij
和 ii->i
。
- 矩陣乘法中,箭頭左邊有
k
而箭頭右邊沒有。 - 對角線元素中,左邊和右邊都只有
i
。 - 矩陣乘法省略了 \(\Sigma\),對角線元素沒有省略 \(\Sigma\)
基於上面的觀察,可以大致推測出來,左邊出現但是右邊沒有出現的符號,這個符號省略了 \(\Sigma\)。
接下來,我們來嘗試一下求解矩陣的跡(trace),即對角線元素的和。
t 是常量,A 是二維方陣。按照前面的做法,省略 \(\Sigma\),左右兩邊對調,省去矩陣和 t,剩下的就是 kk->
。
對,你沒有看錯,右邊沒有東西。這意味着,左邊的符號都省略了 \(\Sigma\)。
例4:矩陣轉置
有了前面的基礎,矩陣轉置就很簡單啦,寫表達式。
A 和 B 都是二維方陣。einsum 可以表達為 ij->ji
。
在 pytorch 中,還支持省略前面的維度。比如,只轉置最后兩個維度,可以表達為 ...ij->...ji
。下面展示了一個含有四個二維矩陣的三維矩陣,轉置三維矩陣中的每個二維矩陣。
einsum 表達式
在快速上手了 einsum 之后,接下來考察一下 einsum,補充一些細節。
pytorch 的文檔寫得非常清楚了:https://pytorch.org/docs/stable/generated/torch.einsum.html
下面總結一下規則:
- 表達式由輸入和輸出兩部分組成。例子,
ij->ji
- 輸出可以省略,箭頭也可以省略。輸入中僅出現一次的字符將按照字母序構成輸出。例子,
ba
完整的表達式是ba->ab
- 輸入中多次出現的字符,將被用作求和。例子,
kj,ji
完整的表達式是kj,ji->ik
,矩陣乘法再相乘。 - 輸出可以指定,但是輸出中的每個字符必須在輸入中出現至少一次,輸出的每個字符在輸出中只能出現最多一次。例子,
ab->aa
是非法的,ab->c
是非法的,ab->a
是合法的。 - 省略符
...
是用來跳過部分維度。例子,...ij,...jk
表示 batch 矩陣乘法。 - 在輸出沒有指定的情況下,省略符優先級高於普通字符。例子,
b...a
完整的表達式是b...a->...ab
,可以將一個形狀為(a,b,c)
的矩陣變為形狀為(b,c,a)
的矩陣。 - 允許多個矩陣輸入,表達式中使用逗號分開不同矩陣輸入的下標。例子,
i,i,i
表示將三個一維向量按位相乘,並相加。 - 除了箭頭,其他任何地方都可以加空格。例子,
i j , j k -> ik
是合法的,ij,jk - > ik
是非法的。 - 輸入的表達式,維度需要和輸入的矩陣對上,不能多也不能少。比如一個 shape 為
(4,3,3)
的矩陣,表達式ab->a
是非法的,abc->
是合法的。
練習
練習1:向量內積
練習2:向量外積
練習3:矩陣按行求和
練習4:矩陣按列求和
練習5:轉置最后兩維
總結
上面簡單介紹了 einsum 的用法,einsum 可以表達很多操作,比如矩陣乘法、矩陣轉置。上面僅僅是做了簡單的介紹,大家在煉丹的過程中,遇到過哪些讓你直呼牛逼的 einsum 表達式呢?希望能得到你的分享,感謝!