Gumbel-Softmax Trick和Gumbel分布


  之前看MADDPG論文的時候,作者提到在離散的信息交流環境中,使用了Gumbel-Softmax estimator。於是去搜了一下,發現該技巧應用甚廣,如深度學習中的各種GAN、強化學習中的A2C和MADDPG算法等等。只要涉及在離散分布上運用重參數技巧時(re-parameterization),都可以試試Gumbel-Softmax Trick。

  這篇文章是學習以下鏈接之后的個人理解,內容也基本出於此,需要深入理解的可以自取。


  這篇文章從直觀感覺講起,先講Gumbel-Softmax Trick用在哪里及如何運用,再編程感受Gumbel分布的效果,最后討論數學證明。

目錄

一、Gumbel-Softmax Trick用在哪里

問題來源

  通常在強化學習中,如果動作空間是離散的,比如上、下、左、右四個動作,通常的做法是網絡輸出一個四維的one-hot向量(不考慮空動作),分別代表四個動作。比如[1,0,0,0]代表上,[0,1,0,0]代表下等等。而具體取哪個動作呢,就根據輸出的每個維度的大小,選擇值最大的作為輸出動作,即\(\arg\max(v)\)

  例如網絡輸出的四維向量為\(v=[-20,10,9.6,6.2]\),第二個維度取到最大值10,那么輸出的動作就是[0,1,0,0],也就是下,這和多類別的分類任務是一個道理。但是這種取法有個問題是不能計算梯度,也就不能更新網絡。通常的做法是加softmax函數,把向量歸一化,這樣既能計算梯度,同時值的大小還能表示概率的含義。softmax函數定義如下:

\[\sigma(z_i)=\frac{e^{z_i}}{\sum\limits_{j=1}^Ke^{z_j}} \]

  那么將\(v=[-20,10,9.6,6.2]\)通過softmax函數后有\(\sigma(v)=[0,0.591,0.396,0.013]\),這樣做不會改變動作或者說類別的選取,同時softmax傾向於讓最大值的概率顯著大於其他值,比如這里10和9.6經過softmax放縮之后變成了0.591和0.396,6.2對應的概率更是變成了0.013,這有利於把網絡訓成一個one-hot輸出的形式,這種方式在分類問題中是常用方法。

  但是這么做還有一個問題,這個表示概率的向量\(\sigma(v)=[0,0.591,0.396,0.013]\)並沒有真正顯示出概率的含義,因為一旦某個值最大,就選擇相應的動作或者分類。比如\(\sigma(v)=[0,0.591,0.396,0.013]\)\(\sigma(v)=[0,0.9,0.1,0]\)在類別選取的結果看來沒有任何差別,都是選擇第二個類別,但是從概率意義上講差別是巨大的。所以需要一種方法不僅選出動作,而且遵從概率的含義。

  很直接的方法是依概率采樣就完事了,比如直接用np.random.choice函數依照概率生成樣本值,這樣概率就有意義了。這樣做確實可以,但是又有一個問題冒了出來:這種方式怎么計算梯度?不能計算梯度怎么用BP的方式更新網絡?

  這時重參數(re-parameterization)技巧解決了這個問題,這里有詳盡的解釋,不過比較晦澀。簡單來說重參數技巧的一個用處是把采樣的步驟移出計算圖,這樣整個圖就可以計算梯度BP更新了。之前我一直在想分類任務直接softmax之后BP更新不就完事了嗎,為什么非得采樣。后來看了VAE和GAN之后明白,還有很多需要采樣訓練的任務。這里舉簡單的VAE(變分自編碼器)的例子說明需要采樣訓練的任務以及重參數技巧,詳細內容來自視頻博客

Re-parameterization Trick

  最原始的自編碼器通常長這樣:



  左右兩邊是端到端的出入輸出網絡,中間的綠色是提取的特征向量,這是一種直接從圖片提取特征的方式。
  而VAE長這樣:



  VAE的想法是不直接用網絡去提取特征向量,而是提取這張圖像的分布特征,也就把綠色的特征向量替換為分布的參數向量,比如說均值和標准差。然后需要decode圖像的時候,就從encode出來的分布中采樣得到特征向量樣本,用這個樣本去重建圖像,這時怎么計算梯度的問題就出現了。
  重參數技巧可以解決這個問題,它長下面這樣:



  假設圖中的\(x\)\(\phi\)表示VAE中的均值和標准差向量,它們是確定性的節點。而需要輸出的樣本\(z\)是帶有隨機性的節點,重參數就是把帶有隨機性的\(z\)變成確定性的節點,同時隨機性用另一個輸入節點\(\epsilon\)代替。例如,這里用正態分布采樣,原本從均值為\(x\)和標准差為\(\phi\)的正態分布\(N(x,\phi^2)\)中采樣得到\(z\)。將其轉化成從標准正態分布\(N(0,1)\)中采樣得到\(\epsilon\),再計算得到\(z=x+\epsilon\cdot \phi\)。這樣一來,采樣的過程移出了計算圖,整張計算圖就可以計算梯度進行更新了,而新加的\(\epsilon\)的輸入分支不做更新,只當成一個沒有權重變化的輸入。

  到這里,需要采樣訓練的任務實例以及重參數技巧基本有個概念了。

Gumbel-Softmax Trick

  VAE的例子是一個連續分布(正態分布)的重參數,離散分布的情況也一樣,首先需要可以采樣,使得離散的概率分布有意義而不是只取概率最大的值,其次需要可以計算梯度。那么怎么做到的,具體操作如下:

  對於\(n\)維概率向量\(\pi\),對\(\pi\)對應的離散隨機變量\(x_{\pi}\)添加Gumbel噪聲,再取樣

\[x_{\pi}=\arg\max(\log(\pi_i)+G_i) \]

  其中,\(G_i\)是獨立同分布的標准Gumbel分布的隨機變量,標准Gumbel分布的CDF為\(F(x)=e^{-e^{-x}}\)
  這就是Gumbel-Max trick。可以看到由於這中間有一個\(\arg\max\)操作,這是不可導的,所以用softmax函數代替之,也就是Gumbel-Softmax Trick,而\(G_i\)可以通過Gumbel分布求逆從均勻分布生成,即\(G_i=-\log(-\log(U_i)),U_i\sim U(0,1)\),這樣就搞定了。

  具體實踐是這樣操作的,

  • 對於網絡輸出的一個\(n\)維向量\(v\),生成\(n\)個服從均勻分布\(U(0,1)\)的獨立樣本\(\epsilon_1,...,\epsilon_n\)
  • 通過\(G_i=-\log(-\log(\epsilon_i))\)計算得到\(G_i\)
  • 對應相加得到新的值向量\(v'=[v_1+G_1,v_2+G_2,...,v_n+G_n]\)
  • 通過softmax函數

\[\sigma_{\tau}(v'_i)=\frac{e^{v'_i/\tau}}{\sum\limits_{j=1}^ne^{v'_j/\tau}} \]

  計算概率大小得到最終的類別。其中\(\tau\)是溫度參數。

  直觀上感覺,對於強化學習來說,在選擇動作之前加一個擾動,相當於增加探索度,感覺上是合理的。對於深度學習的任務來說,添加隨機性去模擬分布的樣本生成,也是合情合理的。

二、Gumbel分布采樣效果

  為什么使用Gumbel分布生成隨機數,就能模擬離散概率分布的樣本呢?這部分使用代碼模擬來感受它的優越性。這部分例子和代碼來自這里

  首先Gumbel分布的概率密度函數長這樣:

\[p(x)=\frac{1}{\beta}e^{-z-e^{-z}} \]

  其中\(z=\frac{x-\mu}{\beta}\)

  Gumbel分布是一類極值分布,那么它表示什么含義呢?原鏈接舉了一個ice cream的例子,沒有get到點。這里舉一個類似的喝水的例子。
  比如你每天都會喝很多次水(比如100次),每次喝水的量也不一樣。假設每次喝水的量服從正態分布\(N(\mu,\sigma^2)\)(其實也有點不合理,畢竟喝水的多少不能取為負值,不過無傷大雅能理解就好,假設均值為5),那么每天100次喝水里總會有一個最大值,這個最大值服從的分布就是Gumbel分布。實際上,只要是指數族分布,它的極值分布都服從Gumbel分布。那么上面這個例子的分布長什么樣子呢,作圖有

from scipy.optimize import curve_fit
import numpy as np
import matplotlib.pyplot as plt
mean_hunger = 5
samples_per_day = 100
n_days = 10000
samples = np.random.normal(loc=mean_hunger, size=(n_days, samples_per_day))
daily_maxes = np.max(samples, axis=1)

def gumbel_pdf(prob,loc,scale):
    z = (prob-loc)/scale
    return np.exp(-z-np.exp(-z))/scale

def plot_maxes(daily_maxes):
    probs,hungers,_=plt.hist(daily_maxes,density=True,bins=100)
    plt.xlabel('Volume')
    plt.ylabel('Probability of Volume being daily maximum')
    (loc,scale),_=curve_fit(gumbel_pdf,hungers[:-1],probs)
    #curve_fit用於曲線擬合
    #接受需要擬合的函數(函數的第一個參數是輸入,后面的是要擬合的函數的參數)、輸入數據、輸出數據
    #返回的是函數需要擬合的參數
    # https://blog.csdn.net/guduruyu/article/details/70313176
    plt.plot(hungers,gumbel_pdf(hungers,loc,scale))
    
plt.figure()
plot_maxes(daily_maxes)



  那么gumbel分布在離散分布的采樣中效果如何呢?可以作圖比較一下。先定義一個多項分布,作出真實的概率密度圖。再通過采樣的方式比較各種方法的效果。

  如下代碼定義了一個7類別的多項分布,其真實的密度函數如下圖

n_cats = 7
cats = np.arange(n_cats)
probs = np.random.randint(low=1, high=20, size=n_cats)
probs = probs / sum(probs)
logits = np.log(probs)
def plot_probs():
    plt.bar(cats, probs)
    plt.xlabel("Category")
    plt.ylabel("Probability")
plt.figure()
plot_probs()



  首先我們直接根據真實的分布利用np.random.choice函數采樣對比效果

n_samples = 1000
def plot_estimated_probs(samples,ylabel=''):
    n_cats = np.max(samples)+1
    estd_probs,_,_ = plt.hist(samples,bins=np.arange(n_cats+1),align='left',edgecolor='white',density=True)
    plt.xlabel('Category')
    plt.ylabel(ylabel+'Estimated probability')
    return estd_probs
def print_probs(probs):
    print('  '.join(['{:.2f}']`len(probs)).format(`probs))

samples = np.random.choice(cats,p=probs,size=n_samples) 

plt.figure()
plt.subplot(1,2,1)
plot_probs()
plt.subplot(1,2,2)
estd_probs = plot_estimated_probs(samples)
plt.tight_layout()#緊湊顯示圖片

print('Original probabilities:\t\t',end='')
print_probs(probs)
print('Estimated probabilities:\t',end='')
print_probs(estd_probs)



Original probabilities:  0.11 0.05 0.12 0.21 0.12 0.26 0.14
Estimated probabilities: 0.12 0.04 0.12 0.23 0.10 0.26 0.13

  效果意料之中的好。可以想到要是沒有不能求梯度這個問題,直接從原分布采樣是再好不過的。

  接着通過前述的方法添加Gumbel噪聲采樣,同時也添加正態分布和均勻分布的噪聲作對比

def sample_gumbel(logits):
    noise = np.random.gumbel(size=len(logits))
    sample = np.argmax(logits+noise)
    return sample
gumbel_samples = [sample_gumbel(logits) for _ in range(n_samples)]

def sample_uniform(logits):
    noise = np.random.uniform(size=len(logits))
    sample = np.argmax(logits+noise)
    return sample
uniform_samples = [sample_uniform(logits) for _ in range(n_samples)]

def sample_normal(logits):
    noise = np.random.normal(size=len(logits))
    sample = np.argmax(logits+noise)
    return sample
normal_samples = [sample_normal(logits) for _ in range(n_samples)]

plt.figure(figsize=(10,4))
plt.subplot(1,4,1)
plot_probs()
plt.subplot(1,4,2)
gumbel_estd_probs = plot_estimated_probs(gumbel_samples,'Gumbel ')
plt.subplot(1,4,3)
normal_estd_probs = plot_estimated_probs(normal_samples,'Normal ')
plt.subplot(1,4,4)
uniform_estd_probs = plot_estimated_probs(uniform_samples,'Uniform ')
plt.tight_layout()

print('Original probabilities:\t\t',end='')
print_probs(probs)
print('Gumbel Estimated probabilities:\t',end='')
print_probs(gumbel_estd_probs)
print('Normal Estimated probabilities:\t',end='')
print_probs(normal_estd_probs)
print('Uniform Estimated probabilities:',end='')
print_probs(uniform_estd_probs)



Original probabilities:      0.11 0.05 0.12 0.21 0.12 0.26 0.14
Gumbel Estimated probabilities: 0.11 0.04 0.11 0.23 0.12 0.26 0.14
Normal Estimated probabilities:  0.08 0.02 0.11 0.26 0.11 0.29 0.12
Uniform Estimated probabilities: 0.00 0.00 0.00 0.32 0.01 0.63 0.03

  可以明顯看到Gumbel噪聲的采樣效果是最好的,正態分布其次,均勻分布最差。也就是說可以用Gumbel分布做Re-parameterization使得整個圖計算可導,同時樣本點最接近真實分布的樣本。

三、數學證明

  為什么添加Gumbel噪聲有如此效果,下面闡述問題並給出證明。

  假設有一個\(K\)維的輸出向量,每個維度的值記為\(x_k\),通過softmax函數可得,取到每個維度的概率為:

\[\pi_k=\frac{e^{x_k}}{\sum^K_{k'=1}e^{x'_k}} \]

  這是直接softmax得到的概率密度函數,如果換一種方式,對每個\(x_k\)添加獨立的標准Gumbel分布(尺度參數為1,位置參數為0)噪聲,並選擇值最大的維度作為輸出,得到的概率密度同樣為\(\pi_k\)

  下面給出Gumbel分布的概率密度函數和分布函數,並證明這件事情。

  尺度參數為1,位置參數為\(\mu\)的Gumbel分布的PDF為

\[f(z;\mu)=e^{-(z-\mu)-e^{-(z-\mu)}} \]

  CDF為

\[F(z;\mu)=e^{-e^{-(z-\mu)}} \]

  假設第\(k\)個Gumbel分布對應\(x_k\),加和得到隨機變量\(z_k=x_k+G_k\),即相當於\(z_k\)服從尺度參數為1,位置參數為\(\mu=x_k\)的Gumbel分布。要證明這樣取得的隨機變量\(z_k\)與原隨機變量相同,只需證明取到\(z_k\)的概率為\(\pi_k\)。也就是\(z_k\)比其他所有\(z_{k'}(k'\not=k)\)大的概率為\(\pi_k\),即

\[P(z_k\ge z_{k'};\forall k'\not = k|\{x_{k'}\}_{k'=1}^K)=\pi_k \]

  關於\(z_k\)的條件累積概率分布函數為

\[P(z_k\ge z_{k'};\forall k'\not = k|z_k,\{x_{k'}\}_{k'=1}^K)=P(z_1\le z_k)P(z_2\le z_k)\cdot\cdot\cdot P(z_{k-1}\le z_{k})P(z_{k+1}\le z_{k})\cdot\cdot\cdot P(z_K\le z_k) \]

  即

\[P(z_k\ge z_{k'};\forall k'\not = k|z_k,\{x_{k'}\}_{k'=1}^K)=\prod\limits_{k'\not= k}e^{-e^{-(z_k-x_{k'})}} \]

  對\(z_k\)求積分可得邊緣累積概率分布函數

\[P(z_k\ge z_{k'};\forall k'\not = k|\{x_{k'}\}_{k'=1}^K)=\int P(z_k\ge z_{k'};\forall k'\not = k|z_k,\{x_{k'}\}_{k'=1}^K)\cdot f(z_k;x_k)\,dz_k \]

  帶入式子有

\[P(z_k\ge z_{k'};\forall k'\not = k|\{x_{k'}\}_{k'=1}^K)=\int \prod\limits_{k'\not= k}e^{-e^{-(z_k-x_{k'})}}\cdot e^{-(z_k-x_k)-e^{-(z_k-x_k)}}\,dz_k \]

  化簡有

\[\begin{array}{l} P(z_k\ge z_{k'};\forall k'\not = k|\{x_{k'}\}_{k'=1}^K)\\ \qquad \qquad =\int \prod_{k'\not= k}e^{-e^{-(z_k-x_{k'})}}\cdot e^{-(z_k-x_k)-e^{-(z_k-x_k)}}\,dz_k \\ \qquad \qquad = \int e^{-\sum_{k'\not=k}e^{-(z_k-x_{k'})}-(z_k-x_k)-e^{-(z_k-x_k)}}\,dz_k\\ \qquad \qquad = \int e^{-\sum_{k'=1}^Ke^{-(z_k-x_{k'})}-(z_k-x_k)}\,dz_k\\ \qquad \qquad = \int e^{-(\sum_{k'=1}^Ke^{x_{k'}})e^{-z_k}-z_k+x_k}\,dz_k\\ \qquad \qquad = \int e^{-e^{-z_k+\ln(\sum_{k'=1}^Ke^{x_{k'}})}-z_k+x_k}\,dz_k \\ \qquad \qquad = \int e^{-e^{-(z_k-\ln(\sum_{k'=1}^Ke^{x_{k'}}))}-(z_k-\ln(\sum_{k'=1}^Ke^{x_{k'}}))-\ln(\sum_{k'=1}^Ke^{x_{k'}})+x_k}\,dz_k \\ \qquad \qquad = e^{-\ln(\sum_{k'=1}^Ke^{x_{k'}})+x_k}\int e^{-e^{-(z_k-\ln(\sum_{k'=1}^Ke^{x_{k'}}))}-(z_k-\ln(\sum_{k'=1}^Ke^{x_{k'}}))}\,dz_k\\ \qquad \qquad = \frac{e^{x_k}}{\sum_{k'=1}^Ke^{x_{k'}}}\int e^{-e^{-(z_k-\ln(\sum_{k'=1}^Ke^{x_{k'}}))}-(z_k-\ln(\sum_{k'=1}^Ke^{x_{k'}}))}\,dz_k \\ \qquad \qquad = \frac{e^{x_k}}{\sum_{k'=1}^Ke^{x_{k'}}}\int e^{-(z_k-\ln(\sum_{k'=1}^Ke^{x_{k'}}))-e^{-(z_k-\ln(\sum_{k'=1}^Ke^{x_{k'}}))}}\,dz_k \end{array} \]

  積分里面是\(\mu=\ln(\sum_{k'=1}^Ke^{x_{k'}})\)的Gumbel分布,所以整個積分為1。則有

\[P(z_k\ge z_{k'};\forall k'\not = k|\{x_{k'}\}_{k'=1}^K)=\frac{e^{x_k}}{\sum_{k'=1}^Ke^{x_{k'}}} \]

  這和softmax的結果一致。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM