TensorFlow NMT的數據處理過程


在tensorflow/nmt項目中,訓練數據和推斷數據的輸入使用了新的Dataset API,應該是tensorflow 1.2之后引入的API,方便數據的操作。如果你還在使用老的Queue和Coordinator的方式,建議升級高版本的tensorflow並且使用Dataset API。

本教程將從訓練數據和推斷數據兩個方面,詳解解析數據的具體處理過程,你將看到文本數據如何轉化為模型所需要的實數,以及中間的張量的維度是怎么樣的,batch_size和其他超參數又是如何作用的。

訓練數據的處理

先來看看訓練數據的處理。訓練數據的處理比推斷數據的處理稍微復雜一些,弄懂了訓練數據的處理過程,就可以很輕松地理解推斷數據的處理。
訓練數據的處理代碼位於nmt/utils/iterator_utils.py文件內的get_iterator函數。

函數的參數

我們先來看看這個函數所需要的參數是什么意思:

參數 解釋
src_dataset 源數據集
tgt_dataset 目標數據集
src_vocab_table 源數據單詞查找表,就是個單詞和int類型數據的對應表
tgt_vocab_table 目標數據單詞查找表,就是個單詞和int類型數據的對應表
batch_size 批大小
sos 句子開始標記
eos 句子結尾標記
random_seed 隨機種子,用來打亂數據集的
num_buckets 桶數量
src_max_len 源數據最大長度
tgt_max_len 目標數據最大長度
num_parallel_calls 並發處理數據的並發數
output_buffer_size 輸出緩沖區大小
skip_count 跳過數據行數
num_shards 將數據集分片的數量,分布式訓練中有用
shard_index 數據集分片后的id
reshuffle_each_iteration 是否每次迭代都重新打亂順序

上面的解釋,如果有不清楚的,可以查看我之前一片介紹超參數的文章:
tensorflow_nmt的超參數詳解

我們首先搞清楚幾個重要的參數是怎么來的。
src_datasettgt_dataset是我們的訓練數據集,他們是逐行一一對應的。比如我們有兩個文件src_data.txttgt_data.txt分別對應訓練數據的源數據和目標數據,那么它們的Dataset如何創建的呢?其實利用Dataset API很簡單:

src_dataset=tf.data.TextLineDataset('src_data.txt') tgt_dataset=tf.data.TextLineDataset('tgt_data.txt') 

這就是上述函數中的兩個參數src_datasettgt_dataset的由來。

src_vocab_tabletgt_vocab_table是什么呢?同樣顧名思義,就是這兩個分別代表源數據詞典的查找表和目標數據詞典的查找表,實際上查找表就是一個字符串到數字的映射關系。當然,如果我們的源數據和目標數據使用的是同一個詞典,那么這兩個查找表的內容是一模一樣的。很容易想到,肯定也有一種數字到字符串的映射表,這是肯定的,因為神經網絡的數據是數字,而我們需要的目標數據是字符串,因此它們之間肯定有一個轉換的過程,這個時候,就需要我們的reverse_vocab_table來作用了。

我們看看這兩個表是怎么構建出來的呢?代碼很簡單,利用tensorflow庫中定義的lookup_ops即可:

def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab): """Creates vocab tables for src_vocab_file and tgt_vocab_file.""" src_vocab_table = lookup_ops.index_table_from_file( src_vocab_file, default_value=UNK_ID) if share_vocab: tgt_vocab_table = src_vocab_table else: tgt_vocab_table = lookup_ops.index_table_from_file( tgt_vocab_file, default_value=UNK_ID) return src_vocab_table, tgt_vocab_table 

我們可以發現,創建這兩個表的過程,就是將詞典中的每一個詞,對應一個數字,然后返回這些數字的集合,這就是所謂的詞典查找表。效果上來說,就是對詞典中的每一個詞,從0開始遞增的分配一個數字給這個詞。

那么到這里你有可能會有疑問,我們詞典中的詞和我們自定義的標記sos等是不是有可能被映射為同一個整數而造成沖突?這個問題該如何解決?聰明如你,這個問題是存在的。那么我們的項目是如何解決的呢?很簡單,那就是將我們自定義的標記當成詞典的單詞,然后加入到詞典文件中,這樣一來,lookup_ops操作就把標記當成單詞處理了,也就就解決了沖突!

具體的過程,本文后面會有一個例子,可以為您呈現具體過程。
如果我們指定了share_vocab參數,那么返回的源單詞查找表和目標單詞查找表是一樣的。我們還可以指定一個default_value,在這里是UNK_ID,實際上就是0。如果不指定,那么默認值為-1。這就是查找表的創建過程。如果你想具體的知道其代碼實現,可以跳轉到tensorflow的C++核心部分查看代碼(使用PyCharm或者類似的IDE)。

數據集的處理過程

該函數處理訓練數據的主要代碼如下:

if not output_buffer_size: output_buffer_size = batch_size * 1000 src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32) tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32) src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset)) src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index) if skip_count is not None: src_tgt_dataset = src_tgt_dataset.skip(skip_count) src_tgt_dataset = src_tgt_dataset.shuffle( output_buffer_size, random_seed, reshuffle_each_iteration) src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: ( tf.string_split([src]).values, tf.string_split([tgt]).values), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Filter zero length input sequences. src_tgt_dataset = src_tgt_dataset.filter( lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)) if src_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src[:src_max_len], tgt), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) if tgt_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tgt[:tgt_max_len]), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Convert the word strings to ids. Word strings that are not in the # vocab get the lookup table's default_value integer. src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32), tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>. src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tf.concat(([tgt_sos_id], tgt), 0), tf.concat((tgt, [tgt_eos_id]), 0)), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Add in sequence lengths. src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt_in, tgt_out: ( src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 

我們逐步來分析,這個過程到底做了什么,數據張量又是如何變化的。

我們知道,對於源數據和目標數據,每一行數據,我們都可以使用一些標記來表示數據的開始和結束,在本項目中,我們可以通過soseos兩個參數指定句子開始標記和結束標記,默認值分別為**和**。本部分代碼一開始就是將這兩個句子標記表示成一個整數,代碼如下:

src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32) tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32) 

過程很簡單,就是通過兩個字符串到整形的查找表,根據soseos的字符串,找到對應的整數,用改整數來表示這兩個標記,並且將這兩個整數轉型為int32類型。
接下來做的是一些常規操作,解釋如注釋:

# 通過zip操作將源數據集和目標數據集合並在一起 # 此時的張量變化 [src_dataset] + [tgt_dataset] ---> [src_dataset, tgt_dataset] src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset)) # 數據集分片,分布式訓練的時候可以分片來提高訓練速度 src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index) if skip_count is not None: # 跳過數據,比如一些文件的頭尾信息行 src_tgt_dataset = src_tgt_dataset.skip(skip_count) # 隨機打亂數據,切斷相鄰數據之間的聯系 # 根據文檔,該步驟要盡早完成,完成該步驟之后在進行其他的數據集操作 src_tgt_dataset = src_tgt_dataset.shuffle( output_buffer_size, random_seed, reshuffle_each_iteration) 

接下來就是重點了,我將用注釋的形式給大家解釋:

  # 將每一行數據,根據“空格”切分開來 # 這個步驟可以並發處理,用num_parallel_calls指定並發量 # 通過prefetch來預獲取一定數據到緩沖區,提升數據吞吐能力 # 張量變化舉例 ['上海 浦東', '上海 浦東'] ---> [['上海', '浦東'], ['上海', '浦東']] src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: ( tf.string_split([src]).values, tf.string_split([tgt]).values), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # 過濾掉長度為0的數據 src_tgt_dataset = src_tgt_dataset.filter( lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))  # 限制源數據最大長度 if src_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src[:src_max_len], tgt), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)  # 限制目標數據的最大長度 if tgt_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tgt[:tgt_max_len]), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # 通過map操作將字符串轉換為數字 # 張量變化舉例 [['上海', '浦東'], ['上海', '浦東']] ---> [[1, 2], [1, 2]] src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32), tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # 給目標數據加上 sos, eos 標記 # 張量變化舉例 [[1, 2], [1, 2]] ---> [[1, 2], [sos_id, 1, 2], [1, 2, eos_id]] src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tf.concat(([tgt_sos_id], tgt), 0), tf.concat((tgt, [tgt_eos_id]), 0)), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # 增加長度信息 # 張量變化舉例 [[1, 2], [sos_id, 1, 2], [1, 2, eos_id]] ---> [[1, 2], [sos_id, 1, 2], [1, 2, eos_id], [src_size], [tgt_size]] src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt_in, tgt_out: ( src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 

其實到這里,基本上數據已經處理好了,可以拿去訓練了。但是有一個問題,那就是我們的每一行數據長度大小不一。這樣拿去訓練其實是需要很大的運算量的,那么有沒有辦法優化一下呢?有的,那就是數據對齊處理。

如何對齊數據

數據對齊的代碼如下,使用注釋的方式來解釋代碼:

# 參數x實際上就是我們的 dataset 對象 def batching_func(x): # 調用dataset的padded_batch方法,對齊的同時,也對數據集進行分批 return x.padded_batch( batch_size, # 對齊數據的形狀 padded_shapes=( # 因為數據長度不定,因此設置None tf.TensorShape([None]), # src # 因為數據長度不定,因此設置None tf.TensorShape([None]), # tgt_input # 因為數據長度不定,因此設置None tf.TensorShape([None]), # tgt_output # 數據長度張量,實際上不需要對齊 tf.TensorShape([]), # src_len tf.TensorShape([])), # tgt_len # 對齊數據的值 padding_values=( # 用src_eos_id填充到 src 的末尾 src_eos_id, # src # 用tgt_eos_id填充到 tgt_input 的末尾 tgt_eos_id, # tgt_input # 用tgt_eos_id填充到 tgt_output 的末尾 tgt_eos_id, # tgt_output 0, # src_len -- unused 0)) # tgt_len -- unused 

這樣就完成了數據的對齊,並且將數據集按照batch_size完成了分批。

num_buckets分桶到底起什么作用

num_buckets起作用的代碼如下:  

  if num_buckets > 1: def key_func(unused_1, unused_2, unused_3, src_len, tgt_len): # Calculate bucket_width by maximum source sequence length. # Pairs with length [0, bucket_width) go to bucket 0, length # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length # over ((num_bucket-1) * bucket_width) words all go into the last bucket. if src_max_len: bucket_width = (src_max_len + num_buckets - 1) // num_buckets else: bucket_width = 10 # Bucket sentence pairs by the length of their source sentence and target # sentence. bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width) return tf.to_int64(tf.minimum(num_buckets, bucket_id)) def reduce_func(unused_key, windowed_data): return batching_func(windowed_data) batched_dataset = src_tgt_dataset.apply( tf.contrib.data.group_by_window( key_func=key_func, reduce_func=reduce_func, window_size=batch_size)) else: batched_dataset = batching_func(src_tgt_dataset) 

num_buckets顧名思義就是桶的數量,那么這個桶用來干嘛呢?我們先看看上面兩個函數到底做了什么。
首先是判斷我們指定的參數num_buckets是否大於1,如果是那么就進入到上述的作用過程。

key_func是做什么的呢?通過源碼和注釋我們發現,它是用來將我們的數據集(由源數據和目標數據成對組成)按照一定的方式進行分類的。具體說來就是,根據我們數據集每一行的數據長度,將它放到合適的桶里面去,然后返回該數據所在桶的索引。

這個分桶的過程很簡單。假設我們有一批數據,他們的長度分別為3 8 11 16 20 21,我們規定一個bucket_width為10,那么我們的數據分配到具體的桶的情況是怎么樣的呢?因為桶的寬度為10,所以第一個桶放的是小於長度10的數據,第二個桶放的是10-20之間的數據,以此類推。

所以,要進行分桶,我們需要知道數據和bucket_width兩個條件。然后根據一定的簡單計算,即可確定如何分桶。上述代碼首先根據src_max_len來計算bucket_width,然后分桶,然后返回數據分到的桶的索引。就是這么簡單的一個過程。

那么,你或許有疑問了,我干嘛要分桶呢?你仔細回想下剛剛的過程,是不是發現長度差不多的數據都分到相同的桶里面去了!沒錯,這就是我們分桶的目的,相似長度的數據放在一起,能夠提升計算效率!!!

然后要看第二個函數reduce_func,這個函數做了什么呢?其實就做了一件事情,就是把剛剛分桶好的數據,做一個對齊!!!

那么通過分桶和對齊操作之后,我們的數據集就已經成為了一個對齊(也就是說有固定長度)的數據集了!

回到一開始,如果我們的參數num_bucktes不滿足條件呢?那就直接做對齊操作!看代碼便知!
至此,分桶的過程和作用你已經清楚了。


至此,數據處理已經結束了。接下來就可以從處理好的數據集獲取一批一批的數據來訓練了。
那么如何一批一批獲取數據呢?答案是使用迭代器。獲取Dataset的迭代器很簡單,tensorflow提供了API,代碼如下:

  batched_iter = batched_dataset.make_initializable_iterator() (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len, tgt_seq_len) = (batched_iter.get_next()) 

通過迭代器的get_next()方法,就可以獲取之前我們處理好的批量數據啦!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM