技術背景
一般認為Jax是谷歌為了取代TensorFlow而推出的一款全新的端到端可微的框架,但是Jax同時也集成了絕大部分的numpy函數,這就使得我們可以更加簡便的從numpy的計算習慣中切換到GPU的計算中。Jax除了支持GPU的張量運算,更重要的一個方面是Jax還支持谷歌自己的硬件TPU的張量運算。關於張量計算,可以參考前面寫過的這一篇博客。
而標題中的另外一個概念:Hamming Distance是用來衡量兩個字符串之間的相似關系評分算法,如果兩個字符串的所有元素完全相同,那么就會得到一個0的分數,如果兩個長度各為100的字符串完全不相同(即每一個位置的字符都完全不同),那么得到的Hamming Distance就是100。而關於Normalized Hamming Distance的概念,則是為了使得結果更加的收斂,因此在Hamming Distance的基礎之上再除以字符串的總長度,得到一個新的評分。舉個例子說,Boy
和Bob
這兩個字符串的Hamming Distance為1,而Normalized Hamming Distance為\(\frac{1}{3}\)。
Numpy和Jax代碼實現
一般計算Hamming Distance可以通過scipy中自帶的distance.hamming
來計算兩個字符串之間的相似度,然而我們在日常的計算中更多的會把字符串轉化成一個用數字來表示的數組,因此這里我們可以直接使用numpy的equal
函數之后在做一個sum
即可得到我們需要的Hamming Distance,如果再除以一個數組長度,那么就是Normalized Hamming Distance。由於Jax上實現了GPU版本的Numpy的函數,因此這里我們將Numpy的函數和Jax的函數寫到一起來進行對比,尤其是時間上的一個衡量。這里測試的邏輯是:我們先通過Numpy來生成兩個給定維度的隨機數,然后將其轉化成兩個Jax格式的數組,然后分別對這兩組不同格式的數組分別用Numpy和Jax計算Hamming Distance,最終統計多次運行所得到的時間。
# normalized_hamming_distance.py
import numpy as np
import jax.numpy as jnp
import time
if __name__ == '__main__':
np.random.seed(1)
length = 100000000
arr1 = np.random.randint(5, size=(length,),dtype=np.int32)
arr2 = np.random.randint(5, size=(length,),dtype=np.int32)
arr1_jax = jnp.array(arr1)
arr2_jax = jnp.array(arr2)
# Start Testing
time0 = time.time()
for _ in range(10):
nhd = np.sum(np.equal(arr1,arr2))/length
time1 = time.time()
for _ in range(10):
nhd_jax = jnp.sum(jnp.equal(arr1_jax,arr2_jax))/length
time2 = time.time()
# Result analysis
print ('The normalized hamming distance by numpy is: {}'.format(nhd))
print ('The normalized hamming distance by jax is: {}'.format(nhd_jax))
print ('The time cost by numpy is: {}s'.format(time1-time0))
print ('The time cost by jax is: {}s'.format(time2-time1))
輸出結果如下所示:
The normalized hamming distance by numpy is: 0.20006858
The normalized hamming distance by jax is: 0.20006857812404633
The time cost by numpy is: 1.7030510902404785s
The time cost by jax is: 0.28351473808288574s
經過對比,我們發現Jax所實現的Numpy的GPU版本,可以在幾乎不用改動接口的條件下,極大程度上的加速了Numpy的計算過程。
總結概要
本文通過對比Jax和Numpy計算Normalized Hamming Distance的過程來對比了Jax所實現的Numpy的GPU版本所帶來的加速效果。實際上在維度比較小的時候,Numpy還是有非常輕量級的優勢,此時GPU的加速效果並沒有很好的體現出來。但是在規模較大的輸入場景下,GPU的並行加速效果簡直無敵,而且幾乎沒有改動原本Numpy的函數接口。除此之外,Jax作為一個函數式編程的端到端可微編程框架,支持jit、vmap、pmap和xmap等非常神奇的加速和並行化功能,為深度學習等領域提供了非常強有力的支持。
版權聲明
本文首發鏈接為:https://www.cnblogs.com/dechinphy/p/jax-numpy.html
作者ID:DechinPhy
更多原著文章請參考:https://www.cnblogs.com/dechinphy/
打賞專用鏈接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
騰訊雲專欄同步:https://cloud.tencent.com/developer/column/91958