ctc loss


原文地址:

https://zhuanlan.zhihu.com/p/23309693

https://zhuanlan.zhihu.com/p/23293860

 

CTC:前向計算例子

這里我們直接使用warp-ctc中的變量進行分析。我們定義T為RNN輸出的結果的維數,這個問題的最終輸出維度為alphabet_size。而ground_truth的維數為L。也就是說,RNN輸出的結果為alphabet_size*T的結果,我們要將這個結果和1*L這個向量進行對比,求出最終的Loss。

我們要一步一步地揭開這個算法的細節……當然這個算法的實現代碼有點晦澀……

我們的第一步要順着test_cpu.cpp的路線來分析代碼。第一步我們就是要解析small_test()中的內容。也就是做前向計算,計算對於RNN結果來說,對應最終的ground_truth——t的label的概率。

這個計算過程可以用動態規划的算法求解。我們可以用一個變量來表示動態規划的中間過程,它就是:

\alpha^T_i:表示在RNN計算的時間T時刻,這一時刻對應的ground_truth的label為第i個下標的值t[i]的概率。

這樣的表示有點抽象,我們用一個實際的例子來講解:

RNN結果:[R_1,R_2,R_3,R_4],這里的每一個變量都對應一個列向量。

ground_truth:[g_1,g_2,g_3]

那么\alpha^2_1表示R_2的結果對應着g_1的概率,當然與此同時,前面的結果也都合理地對應完成。

從上面的結果我們可以看出,如果R_2的結果對應着g_1,那么R_1的結果也必然對應着g_1。所以前面的結果是確定的。然而對於其他的一些情況來說,我們的轉換存在着一定的不確定性。

CTC:前向計算具體過程

我們還是按照上面的例子進行計算,我們把剛才的例子搬過來:

RNN結果:[R_1,R_2,R_3,R_4],這里的每一個變量都對應一個列向量。

ground_truth:[g_1,g_2,g_3]

alphabet:[g_0(blank),g_1,g_2,g_3]

按照上面介紹的計算方法,第一步我們先做ground_truth的狀態擴展,於是我們就把長度從3擴展到了7,現在的ground_truth變成了:

[blank,g_1,blank,g_2,blank,g_3,blank]

我們的RNN結果長度為4,也就是說我們會從上面的7個ground_truth狀態中進行轉移,並最終轉移到最終狀態。理論上利用動態規划的算法,我們需要計算4*7=28個中間結果。好了,下面我們用P^T_i表示RNN的第T時刻狀態為ground_truth中是第i個位置的概率。

那么我們就開始計算了:

T=1時,我們只能選擇g_1和blank,所以這一輪我們終結狀態只可能落在0和1上。所以第一輪變成了:

[P^1_0,P^1_1,0,0,0,0,0]

T=2時,我們可以繼續選擇g_1,我們同時也可以選擇g_2,還可以選擇g_1g_2之間的blank,所以我們可以進一步關注這三個位置的概率,於是我們將其他的位置的概率設為0。[0,(P^1_0 +P^1_1)P^2_1,P^1_1P^2_2,P^1_1P^2_3,0,0,0]

T=3時,留給我們的時間已經不多了,我們還剩2步,要走完整個旅程,我們只能選擇g_2g_3以及它們之間的空格。於是乎我們關心的位置又發生了變化:

[0,0,0,
(P^1_1P^2_2+P^1_1P^2_3)P^3_3,
P^1_1P^2_3P^3_4,
P^1_1P^2_3P^3_5,
0]

是不是有點看暈了?沒關系,因為還剩最后一步了。下面是最后一步,因為最后一步我們必須要到g_3以及它后面的空格了,所以我們的概率最終計算也就變成了:

[0,0,0,
0,0,
((P^1_1P^2_2+P^1_1P^2_3)P^3_3+P^1_1P^2_2P^3_4+P^1_1P^2_2P^3_3)P^4_5,
P^1_1P^2_3P^3_5P^4_6]

好吧,最終的結果我們求出來了,實際上這就是通過時間的推移不斷迭代求解出來的。關於迭代求解的公式這里就不再贅述了。我們直接來看一張圖:

於是乎我們從這個計算過程中發現一些問題:

首先是一個相對簡單的問題,我們看到在計算過程中我們發現了大量的連乘。由於每一個數字都是浮點數,那么這樣連乘下去,最終數字有可能非常小而導致underflow。所以我們要將這個計算過程轉到對數域上。這樣我們就將其中的乘法轉變成了加法。但是原本就是加法的計算呢?比方說我們現在計算了loga和logb,我們如何計算log(a+b)呢,這里老司機給出了解決方案,我們假設兩個數中a>b,那么有

log(a+b)=log(a(1+\frac{b}{a}))=loga+log(1+\frac{b}{a})
=loga+log(1+exp(log(\frac{b}{a})))=loga+log(1+exp(logb - loga))

這樣我們就利用了loga和logb計算出了log(a+b)來。

另外一個問題就是,我們發現在剛才的計算過程當中,對於每一個時間段,我們實際上並不需要計算每一個ground-truth位置的概率信息,實際上只要計算滿足某個條件的某一部分就可以了。所以我們有沒有希望在計算前就規划好這條路經,以保證我們只計算最相關的那些值呢?

如何控制計算的數量?

不得不說,這一部分warp-ctc寫得實在有點晦澀,當然也可能是我在這方面的理解比較渣。我們這里主要關注兩個部分——一個是數據的准備,一個是最終的數據的使用。

在介紹數據准備之前,我們先簡單說一下這部分計算的大概思路。我們用兩個變量start和end表示我們需要計算的狀態的起止點,在每一個時間點,我們要更新start和end這兩個變量。然后我們更新start和end之間的概率信息。這里我們先要考慮一個問題,start和end的更新有什么規律?

為了簡化思考,我們先假設ground_truth中沒有重復的label,我們的大腦瞬間得到了解放。好了,下面我們就要給出代碼中的兩個變量——

T:表示RNN結果中的維度

S/2:ground_truth的維度(S表示了擴展blank之后的維度)

基本上具備一點常識,我們就可以知道T>=S/2。什么?你覺得有可能出現T<S/2的情況?兄弟,這種見鬼的事情如果發生,你難道要我們把RNN的結果拆開給你用?臣妾不太能做得到啊……

好了,既然接受了上面的事實,那么我們就來舉幾個例子看看:

我們假設T=3,S/2=3,那么說白了,它們之間的對應關系是一一對應,說白了這就和blank位置沒啥關系了。在T=1時,我們要轉移到第一個結果,T=2,我們要轉移到第二個結果……

 

如何控制計算的數量?cont.

好,廢話少說我們書接上回。不明真相的小朋友先看這個:

下面我們假設T=4,S/2=3,好玩的地方來了。T比S/2多一個,也就是說我們允許冗余出現了,那么我們可能的形式也就變多了。我們可以增加一個blank,我們也可以在沒有label位置原地打一輪醬油。選擇更多,歡樂更多。

雖然選擇變多,但是着並不意味着我們可以選擇任意一種狀態轉移的方式,至少:

  • 在T=2時,我們至少要轉移到第一個結果
  • 在T=3時,我們至少要轉移到第二個結果
  • 在T=4時,兄弟我們准備下車了

這其實就是對start的限制。源代碼中有這樣一句話:

int remain = (S / 2) + repeats - (T - t);

這里我們先忽略repeats,那么remain這個變量其實是在計算label數量和剩余時間的差。如果用這樣的語言來表達剛才的那個問題,我們語言就變成這個樣子:

  • 當時間還剩4輪時(包括第4輪),我們在哪都無所謂(實際上是從T=1開始計算的)
  • 當時間還剩3輪時(包括第3輪),我們至少要轉移到第一個結果(index=1)
  • 當時間還剩2輪時(包括第2輪),我們至少要轉移到第二個結果(index=3)
  • 當時間還剩1輪時(包括第1輪),我們至少要轉移到第三個結果(index=5)

好了,這里我們看出其中的含義了。我們再啰嗦一下,看看這些變量隨T的變化情況:

  • T=1,remain=0,start+=1
  • T=2,remain=1,start+=2
  • T=3,remain=2,start+=2

現在我們已經十分清楚了,當remain>=0時,start都要向前走,限制我們計算前面狀態的概率,因為這些概率已經沒有意義了。下面的代碼也是這樣描述的:

if(remain >= 0)
    start += s_inc[remain];

那么這個s_inc是什么東西?它就是我們需要提前准備好的計算量。我們知道經過擴充的label序列中,所有的非空label都處在奇數的index上,而填充的blank都處在偶數的index上(我們是0-based的計算方法,matlab選手請退散……),所以對於上面的問題,當start=0時,下一步我們會從0跳到1,此后我們會從1到3,3到5,跳轉的步數都是2,所以基於這個思路,我們就可以把s_inc這個數組生成出來。當然,我們的前提是沒有重復。下面我們會說重復的問題的。

我們上面說了這么多,重點把start的變化介紹清楚了。下面我們來看看end。其實end的原理也類似,我們還是用剛才的廢話套路來介紹站在end視角的世界:

  • 在T=1時,我們最多能到第一個結果
  • 在T=2時,我們最多能轉移到第二個結果
  • 在T=3時,我們最多能轉移到第三個結果
  • 在T=4時,我們已經掌握了整個世界……oh yeah

好了,可以看出end的變化形式,每個時刻end都可以+2,直到到達最后一個非blank的label,end變成了+1,然后end就不用動了,等着start動就可以了……(怎么感覺有點污?天哪……)

那么end變化的條件是什么呢?

if(t <= (S / 2) + repeats)
    end += e_inc[t - 1];

我們還是忽略repeats,那么就十分清楚了,如果當前時刻小於等於label數,那么盡管前進,如果大於了,基本上也就到頭了,這時候end就不用動了。

好了,前面我們終於說完了簡單模式下start和end的移動規律,下面我們來看看帶重復模式下的變化方法。

重復,重復

重復會帶來什么樣的變化呢?說白了如果有重復的label出現,那么兩個連續重復的label中間就要至少出現一個blank。換句話說,每出現一個重復,我們的S/2就要加一,於是我們再看一眼這兩個計算公式:

int remain = (S / 2) + repeats - (T - t);
if(remain >= 0)
    start += s_inc[remain];
if(t <= (S / 2) + repeats)
    end += e_inc[t - 1];

我們把repeats和S/2歸到一起,這時候就能看明白了。

同理,在計算s_inc和e_inc的時候,由於有repeats的存在,它們從過去的+2變成了兩個+1。也就是說先從label跳到blank,再跳到下一個label。這樣就可以解釋s_inc和e_inc的初始化策略了:

int e_counter = 0;
int s_counter = 0;

s_inc[s_counter++] = 1;

int repeats = 0;

for (int i = 1; i < L; ++i) {
    if (labels[i-1] == labels[i]) {
        s_inc[s_counter++] = 1;
        s_inc[s_counter++] = 1;
        e_inc[e_counter++] = 1;
        e_inc[e_counter++] = 1;
        ++repeats;
    }
    else {
        s_inc[s_counter++] = 2;
        e_inc[e_counter++] = 2;
    }
}
e_inc[e_counter++] = 1;

好了,到此我們才算把CTC中compute ctc loss這部分介紹完了。教科書上的一個公式看着簡單,落實到代碼就似乎充滿了trick。希望看懂了這個計算的你大腦沒有陣亡。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM