最近用tensorflow寫了個OCR的程序,在實現的過程中,發現自己還是跳了不少坑,在這里做一個記錄,便於以后回憶。主要的內容有lstm+ctc具體的輸入輸出,以及TF中的CTC和百度開源的warpCTC在具體使用中的區別。
正文
輸入輸出
因為我最后要最小化的目標函數就是ctc_loss
,所以下面就從如何構造輸入輸出說起。
tf.nn.ctc_loss
先從TF自帶的tf.nn.ctc_loss說起,官方給的定義如下,因此我們需要做的就是將圖片的label(需要OCR出的結果),圖片,以及圖片的長度
轉換為label,input,和sequence_length。
ctc_loss(
labels,
inputs,
sequence_length,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
time_major=True
)
input: 輸入(訓練)數據,是一個三維float型的數據結構
[max_time_step , batch_size , num_classes]
,當修改time_major = False時,
[batch_size,max_time_step,num_classes]
。
總體的數據流:
image_batch
->
->
->
->
->
->
image_batch
->
[batch_size,max_time_step,num_features]
->lstm
->
[batch_size,max_time_step,cell.output_size]
->reshape
->
[batch_size*max_time_step,num_hidden]
->affine projection
A*W+b
->
[batch_size*max_time_step,num_classes]
->reshape
->
[batch_size,max_time_step,num_classes]
->transpose
->
[max_time_step,batch_size,num_classes]
下面詳細解釋一下,
假如一張圖片有如下shape:[60,160,3],我們如果讀取灰度圖則shape=[60,160],此時,我們將其一列作為feature,那么共有60個features,160個time_step,這時假設一個batch為64,那么我們此時獲得到了一個
假如一張圖片有如下shape:[60,160,3],我們如果讀取灰度圖則shape=[60,160],此時,我們將其一列作為feature,那么共有60個features,160個time_step,這時假設一個batch為64,那么我們此時獲得到了一個
[batch_size,max_time_step,num_features] = [64,160,60]
的訓練數據。
然后將該訓練數據送入
構建的lstm網絡中,(需要注意的是
dynamic_rnn的輸入數據在一個batch內的長度是固定的,但是不同batch之間可以不同,我們需要給他一個
得到形如
sequence_length
(長度為batch_size的向量)來記錄本次batch數據的長度,對於OCR這個問題,sequence_length就是長度為64,而值為160的一維向量)
得到形如
[batch_size,max_time_step,cell.output_size]
的輸出,其中cell.output_size == num_hidden。
下面我們需要做一個線性變換將其送入ctc_loos中進行計算,lstm中不同time_step之間共享權值,所以我們只需定義
W
的結構為
[num_hidden,num_classes]
,
b
的結構為[num_classes]。而
tf.matmul操作中,兩個矩陣相乘階數應當匹配,所以我們將上一步的輸出reshape成
[batch_size*max_time_step,num_hidden]
(num_hidden為自己定義的lstm的unit個數)記為
A
,然后將其做一個線性變換,於是
A*w+b
得到形如
[batch_size*max_time_step,num_classes]
然后在reshape回來得到
[batch_size,max_time_step,num_classes]
最后由於ctc_loss的要求,我們再做一次轉置,得到
[max_time_step,batch_size,num_classes]
形狀的數據作為input
labels: 標簽序列
由於OCR的結果是不定長的,所以label實際上是一個稀疏矩陣SparseTensor,
其中:
indices
:二維int64的矩陣,代表非0的坐標點values
:二維tensor,代表indice位置的數據值dense_shape
:一維,代表稀疏矩陣的大小
比如有兩幅圖,分別是123
,和4567
那么
indecs =[[0,0],[0,1],[0,2],[1,0],[1,1],[1,2],[1,3]]
values = [1,2,3,4,5,6,7]
dense_shape = [2,4]
代表dense tensor:12[[1,2,3,0][4,5,6,7]]
seq_len: 在input一節中已經講過,一維數據,[time_step,…,time_step]長度為batch_size,值為time_step