技術背景
隨機采樣問題,不僅僅只是一個統計學/離散數學上的概念,其實在工業領域也都有非常重要的應用價值/潛在應用價值,具體應用場景我們這里就不做贅述。本文重點在於在不同平台上的采樣速率,至於另外一個重要的參數檢驗速率,這里我們先不做評估。因為在Jax中直接支持vmap的操作,而numpy的原生函數大多也支持了向量化的運算,兩者更像是同一種算法的不同實現。所以對於檢驗的場景,兩者的速度區別更多的也是在硬件平台上。
隨機采樣示例
關於Jax的安裝和基本使用方法,讀者可以自行參考Jax的官方文檔,需要注意的是,Jax有CPU、GPU和TPU三個版本,如果需要使用其GPU版本的功能,還需要依賴於jaxlib,另外最好是指定安裝對應的CUDA版本,這都是安裝過程中所踩過的一些坑。最后如果安裝的不是GPU的版本,運行Jax腳本的時候會有相關的提示說明。
隨機采樣,可以是針對一個給定的連續函數,也可以針對一個離散化的列表,但是為了更好的擴展性,一般問題都會轉化成先獲取均勻的隨機分布,再轉化成其他函數形式的分布,如正態分布等。所以這里我們更加的是關注下均勻分布函數的效率:
import numpy as np
import time
import jax.random as random
key = random.PRNGKey(0)
print ('An small example of numpy sampler: \n{}'.format(np.random.uniform(low=0,high=1,size=5)))
print ('An small example of jax sampler: \n{}'.format(random.uniform(key,shape=(5,),minval=0, maxval=1)))
data_size = 400000000
time0 = time.time()
s = np.random.uniform(low=0,high=1,size=data_size)
print ('The numpy time cost is: {}s'.format(time.time()-time0))
time1 = time.time()
v = random.uniform(key,shape=(data_size,),minval=0, maxval=1)
print ('The jax time cost is: {}s'.format(time.time()-time1))
執行結果如下:
An small example of numpy sampler:
[0.33654613 0.20267496 0.86859762 0.14940831 0.30321738]
An small example of jax sampler:
[0.57450044 0.09968603 0.39316022 0.8941783 0.59656656]
The numpy time cost is: 3.6664984226226807s
The jax time cost is: 0.10985755920410156s
同樣是在生成雙精度浮點數的情況下,我們可預期的GPU的速率在數據長度足夠大的情況下一定是會更快的,這個運算結果也佐證了這個說法。
總結概要
關於工業領域中可能使用到的隨機采樣,更多的是這樣的一個場景:給定一個連續或者離散的分布,然后進行大規模的連續采樣,采樣的同時需要對每一個得到的樣點進行分析打分,最終在這大規模的采樣過程中,有可能被使用到的樣品可能只有其中的幾份。那么這樣的一個抽象問題,就非常適合使用分布式的多GPU硬件架構來實現。
版權聲明
本文首發鏈接為:https://www.cnblogs.com/dechinphy/p/sampler.html
作者ID:DechinPhy
更多原著文章請參考:https://www.cnblogs.com/dechinphy/
打賞專用鏈接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html