在tensorflow/nmt項目中,訓練數據和推斷數據的輸入使用了新的Dataset API,應該是tensorflow 1.2之后引入的API,方便數據的操作。如果你還在使用老的Queue和Coordinator的方式,建議升級高版本的tensorflow並且使用Dataset API。
本教程將從訓練數據和推斷數據兩個方面,詳解解析數據的具體處理過程,你將看到文本數據如何轉化為模型所需要的實數,以及中間的張量的維度是怎么樣的,batch_size和其他超參數又是如何作用的。
訓練數據的處理
先來看看訓練數據的處理。訓練數據的處理比推斷數據的處理稍微復雜一些,弄懂了訓練數據的處理過程,就可以很輕松地理解推斷數據的處理。
訓練數據的處理代碼位於nmt/utils/iterator_utils.py文件內的get_iterator
函數。
函數的參數
我們先來看看這個函數所需要的參數是什么意思:
參數 | 解釋 |
---|---|
src_dataset |
源數據集 |
tgt_dataset |
目標數據集 |
src_vocab_table |
源數據單詞查找表,就是個單詞和int類型數據的對應表 |
tgt_vocab_table |
目標數據單詞查找表,就是個單詞和int類型數據的對應表 |
batch_size |
批大小 |
sos |
句子開始標記 |
eos |
句子結尾標記 |
random_seed |
隨機種子,用來打亂數據集的 |
num_buckets |
桶數量 |
src_max_len |
源數據最大長度 |
tgt_max_len |
目標數據最大長度 |
num_parallel_calls |
並發處理數據的並發數 |
output_buffer_size |
輸出緩沖區大小 |
skip_count |
跳過數據行數 |
num_shards |
將數據集分片的數量,分布式訓練中有用 |
shard_index |
數據集分片后的id |
reshuffle_each_iteration |
是否每次迭代都重新打亂順序 |
上面的解釋,如果有不清楚的,可以查看我之前一片介紹超參數的文章:
tensorflow_nmt的超參數詳解
我們首先搞清楚幾個重要的參數是怎么來的。src_dataset
和tgt_dataset
是我們的訓練數據集,他們是逐行一一對應的。比如我們有兩個文件src_data.txt
和tgt_data.txt
分別對應訓練數據的源數據和目標數據,那么它們的Dataset如何創建的呢?其實利用Dataset API很簡單:
src_dataset=tf.data.TextLineDataset('src_data.txt') tgt_dataset=tf.data.TextLineDataset('tgt_data.txt')
這就是上述函數中的兩個參數src_dataset
和tgt_dataset
的由來。
src_vocab_table
和tgt_vocab_table
是什么呢?同樣顧名思義,就是這兩個分別代表源數據詞典的查找表和目標數據詞典的查找表,實際上查找表就是一個字符串到數字的映射關系。當然,如果我們的源數據和目標數據使用的是同一個詞典,那么這兩個查找表的內容是一模一樣的。很容易想到,肯定也有一種數字到字符串的映射表,這是肯定的,因為神經網絡的數據是數字,而我們需要的目標數據是字符串,因此它們之間肯定有一個轉換的過程,這個時候,就需要我們的reverse_vocab_table來作用了。
我們看看這兩個表是怎么構建出來的呢?代碼很簡單,利用tensorflow庫中定義的lookup_ops即可:
def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab): """Creates vocab tables for src_vocab_file and tgt_vocab_file.""" src_vocab_table = lookup_ops.index_table_from_file( src_vocab_file, default_value=UNK_ID) if share_vocab: tgt_vocab_table = src_vocab_table else: tgt_vocab_table = lookup_ops.index_table_from_file( tgt_vocab_file, default_value=UNK_ID) return src_vocab_table, tgt_vocab_table
我們可以發現,創建這兩個表的過程,就是將詞典中的每一個詞,對應一個數字,然后返回這些數字的集合,這就是所謂的詞典查找表。效果上來說,就是對詞典中的每一個詞,從0開始遞增的分配一個數字給這個詞。
那么到這里你有可能會有疑問,我們詞典中的詞和我們自定義的標記sos
等是不是有可能被映射為同一個整數而造成沖突?這個問題該如何解決?聰明如你,這個問題是存在的。那么我們的項目是如何解決的呢?很簡單,那就是將我們自定義的標記當成詞典的單詞,然后加入到詞典文件中,這樣一來,lookup_ops
操作就把標記當成單詞處理了,也就就解決了沖突!
具體的過程,本文后面會有一個例子,可以為您呈現具體過程。
如果我們指定了share_vocab
參數,那么返回的源單詞查找表和目標單詞查找表是一樣的。我們還可以指定一個default_value,在這里是UNK_ID
,實際上就是0
。如果不指定,那么默認值為-1
。這就是查找表的創建過程。如果你想具體的知道其代碼實現,可以跳轉到tensorflow的C++核心部分查看代碼(使用PyCharm或者類似的IDE)。
數據集的處理過程
該函數處理訓練數據的主要代碼如下:
if not output_buffer_size: output_buffer_size = batch_size * 1000 src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32) tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32) src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset)) src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index) if skip_count is not None: src_tgt_dataset = src_tgt_dataset.skip(skip_count) src_tgt_dataset = src_tgt_dataset.shuffle( output_buffer_size, random_seed, reshuffle_each_iteration) src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: ( tf.string_split([src]).values, tf.string_split([tgt]).values), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Filter zero length input sequences. src_tgt_dataset = src_tgt_dataset.filter( lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)) if src_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src[:src_max_len], tgt), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) if tgt_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tgt[:tgt_max_len]), num_parallel_calls=