TensorFlow從1到2(四)時尚單品識別和保存、恢復訓練數據


Fashion Mnist --- 一個圖片識別的延伸案例

在TensorFlow官方新的教程中,第一個例子使用了由MNIST延伸而來的新程序。
這個程序使用一組時尚單品的圖片對模型進行訓練,比如T恤(T-shirt)、長褲(Trouser),訓練完成后,對於給定圖片,可以識別出單品的名稱。

程序同樣將所有圖片規范為28x28點陣,使用灰度圖,每個字節取值范圍0-255。時尚單品的類型,同樣也是分為10類,跟手寫數字識別的分類維度相同。因此實際上,這個例子看起來美觀也有趣很多,但是在技術層面上,跟傳統的MNIST沒有區別。
不同的地方也有,首先是識別之后需要顯示的是單品名稱,而不是0-9的數字,所以程序中需要定義一個標簽數組,並在顯示時做一個轉換:

	......
# 標簽列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
	......
# 顯示標簽名稱
plt.xlabel(class_names[train_labels[i]])
	......


其次,從樣本圖片中你應當能看出來,圖片的復雜度,比手寫數字還是高多了。從而造成的混淆和誤判,顯然也高的多。這種情況下,只使用tf.argmax()獲取確定的一個標簽就有點不足了。所以在這個例子中,增加了使用直方圖,顯示所有10個預測分類中,每個分類的相似度功能。同時,預測正確的,用藍色字體表示。預測結果同樣本標注不同的,使用紅色字體表示。

	......
def plot_value_array(i, predictions_array, true_label):
    predictions_array, true_label = predictions_array[i], true_label[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = tf.argmax(predictions_array)

    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')
	......
plot_value_array(i, predictions, test_labels)
	......	

完整的代碼如下:

#!/usr/bin/env python3

from __future__ import absolute_import, division, print_function

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

# 顯示樣本集中,指定圖片、預測信息、標注信息
def plot_image(i, predictions_array, true_label, img):
    predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img, cmap=plt.cm.binary)

    predicted_label = tf.argmax(predictions_array)
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'
  
    plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                         100*np.max(predictions_array),
                                         class_names[true_label]),
                                         color=color)


# 使用柱狀圖顯示預測結果數組,每一個柱狀圖,代表圖片屬於該類的可能性
def plot_value_array(i, predictions_array, true_label):
    predictions_array, true_label = predictions_array[i], true_label[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = tf.argmax(predictions_array)

    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')

# 加載Fashion Mnist數據集,第一次執行的時候會自動從網上下載,這個速度會比較慢
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# 如同數字識別的0-9十類,這里也將時尚潮品分了以下十類
# 所以本質上,這跟手寫數字的識別是完全一致的
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# 數據規范化,將圖片數據轉化為0-1之間的浮點數字
train_images = train_images / 255.0
test_images = test_images / 255.0

# 為了有一個直觀印象,我們把訓練集前24個樣本圖片顯示在屏幕上,同時顯示圖片的標注信息
# 你可能注意到了,我們在顯示圖片的時候,並沒有跟前面顯示手寫字體圖片一樣,把圖片的規范化數據還原為0-255,
# 這是因為實際上mathplotlib庫可以直接接受浮點型的圖像數據,
# 我們前面首先還原規范化數據,是為了讓你清楚理解原始數據的格式。
plt.figure(figsize=(8, 6))
for i in range(24):
    plt.subplot(4, 6, i+1)
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

# 定義神經網絡模型,用了一個比較簡單的模型
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

# 采用指定的優化器和損失函數編譯模型
model.compile(optimizer='adam', 
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 訓練模型
model.fit(train_images, train_labels, epochs=15)

# 使用測試集數據評估訓練后的模型,並顯示評估結果
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest accuracy:', test_acc)

#########
# 預測所有測試集數據,用於圖形顯示結果
predictions = model.predict(test_images)

# 以5行x3列顯示測試集前15個樣本的圖片和預測結果
# 正確的預測結果藍色顯示,錯誤的預測信息會紅色顯示
# 每一張圖片的右側,會顯示圖片預測的結果數組,這個數組中,數值最大的,代表最可能的分類
# 或者說,每一個數組元素,都代表圖片屬於對應分類的可能性
num_rows = 5
num_cols = 3
num_images = num_rows*num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(i, predictions, test_labels, test_images)
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(i, predictions, test_labels)
plt.show()

#############
# 演示預測單獨一幅圖片
# 從測試集獲取一幅圖
img = test_images[0]
# 我們的模型是批處理進行預測的,要求的是一個圖片的數組,所以這里擴展一維
# 成為(1, 28, 28)這樣的形式
img = (np.expand_dims(img, 0))
# 使用模型進行預測
predictions_single = model.predict(img)
# 顯示預測結果數組
print("test_images[0] prediction array:", predictions_single)
# 顯示轉換為可識別類型的預測結果
print("test_images[0] prediction text:", class_names[tf.argmax(predictions_single[0])])
# 顯示原標注
print("test_labels[0]:", class_names[test_labels[0]])
# 原圖的顯示請參考上面大圖的左上角第一幅,此處略

程序最后還演示了使用1幅圖片數據調用模型進行預測的方式。特別不要忘記把這一幅圖片擴展一維再進入模型,因為我們的模型是使用批處理方式進行預測的,原本接受的是一個圖片的數組。
程序在第一次執行的時候,會自動由網上下載數據集,下載的網址在下面的顯示信息中能看到。下載完成后,數據會存放在~/.keras/datasets/fashion-mnist/文件夾。

$ ./fashion_mnist.py 
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 15us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 65s 2us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 8us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 10s 2us/step

以后再運行程序的時候,程序就直接使用本地數據運行。執行過程所顯示的信息類似下面:

$ ./fashion_mnist.py
Epoch 1/15
60000/60000 [==============================] - 4s 68us/sample - loss: 0.4999 - accuracy: 0.8247
Epoch 2/15
60000/60000 [==============================] - 4s 63us/sample - loss: 0.3753 - accuracy: 0.8652
Epoch 3/15
60000/60000 [==============================] - 4s 63us/sample - loss: 0.3361 - accuracy: 0.8783
Epoch 4/15
60000/60000 [==============================] - 4s 64us/sample - loss: 0.3120 - accuracy: 0.8848
Epoch 5/15
60000/60000 [==============================] - 4s 63us/sample - loss: 0.2950 - accuracy: 0.8916
Epoch 6/15
60000/60000 [==============================] - 4s 63us/sample - loss: 0.2825 - accuracy: 0.8950
Epoch 7/15
60000/60000 [==============================] - 4s 64us/sample - loss: 0.2681 - accuracy: 0.9004
Epoch 8/15
60000/60000 [==============================] - 4s 63us/sample - loss: 0.2564 - accuracy: 0.9052
Epoch 9/15
60000/60000 [==============================] - 4s 63us/sample - loss: 0.2463 - accuracy: 0.9088
Epoch 10/15
60000/60000 [==============================] - 4s 64us/sample - loss: 0.2385 - accuracy: 0.9118
Epoch 11/15
60000/60000 [==============================] - 5s 79us/sample - loss: 0.2299 - accuracy: 0.9145
Epoch 12/15
60000/60000 [==============================] - 4s 72us/sample - loss: 0.2224 - accuracy: 0.9165
Epoch 13/15
60000/60000 [==============================] - 4s 65us/sample - loss: 0.2152 - accuracy: 0.9192
Epoch 14/15
60000/60000 [==============================] - 4s 64us/sample - loss: 0.2093 - accuracy: 0.9214
Epoch 15/15
60000/60000 [==============================] - 4s 64us/sample - loss: 0.2031 - accuracy: 0.9227
10000/10000 [==============================] - 0s 38us/sample - loss: 0.3361 - accuracy: 0.8889

Test accuracy: 0.8889
test_images[0] prediction array: [[2.8952907e-09 4.0831842e-06 9.7278274e-08 1.6851689e-09 5.8218838e-08
  3.0680697e-03 1.2691763e-07 1.8435927e-02 3.7783199e-08 9.7849166e-01]]
test_images[0] prediction text: Ankle boot
test_labels[0]: Ankle boot

程序執行中,測試集前15幅圖片的驗證結果顯示如下:

左下角的圖片出現了明顯的識別錯誤。不過話說回來,以我這種時尚盲人來說,也完全區分不出來這種樣子的涼鞋跟運動鞋有啥區別(手動捂臉),當然圖片的分辨率也是問題之一啦。

保存和恢復訓練數據

TensorFlow 2.0提供了兩種數據保存和恢復的方式。第一種方式是我們在TensorFlow 1.x中經常用的保存模型權重參數的方式。
因為在TensorFlow 2.0中,我們使用了model.fit方法來代替之前使用的訓練循環,所以保存訓練權重數據是使用回調函數的方式完成的。下面舉一個例子:

	...在model.compile之后增加下面代碼...
checkpoint_path = "training_data/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# 設置自己的回調函數
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir, 
                                                 save_weights_only=True,
                                                 verbose=1)
# 修改fit方法增加回調參數
model.fit(train_images, train_labels, epochs=15,
          callbacks = [cp_callback])  
	......

這樣在每一個訓練周期,都會將訓練數據寫入到文件,屏幕顯示會類似這樣:

Epoch 1/15
60000/60000 [==============================] - 4s 68us/sample - loss: 0.4999 - accuracy: 0.8247
Epoch 00001: saving model to training_data/cp.ckpt
Epoch 2/15
60000/60000 [==============================] - 4s 63us/sample - loss: 0.3753 - accuracy: 0.8652
Epoch 00002: saving model to training_data/cp.ckpt
Epoch 3/15
60000/60000 [==============================] - 4s 63us/sample - loss: 0.3361 - accuracy: 0.8783
Epoch 00003: saving model to training_data/cp.ckpt
Epoch 4/15
	......

對於稍大的數據集和稍微復雜的模型,訓練的時間會非常之長。通常我們都會把這種工作部署到有強大算力的服務器上執行。訓練完成,將訓練數據保存下來。預測的時候,則並不需要很大的運算量,就可以在普通的設備上執行了。
還原保存的數據,其實就是把fit方法這一句,替換為加載保存的數據就可以:

	...替代model.fit那一行代碼...
model.load_weights(checkpoint_dir)
	...然后就可以當做訓練完成的模型一樣進行預測操作了...

這種方法是比較多用的,因為很多情況下,我們訓練所使用的模型,跟預測所使用的模型,會有細微的調整。這時候只載入模型的權重值,並不影響模型的微調。
此外,上面的代碼僅為示例。在實際應用中,這種不改變文件名、只保存一組文件的形式,實際並不需要回調函數,在訓練完成后一次寫入到文件是更好的選擇。使用回調函數通常都是為了保存每一步的訓練結果。

保存完整模型

如果模型是比較成熟穩定的,我們很可能喜歡完整的保存整個模型,這樣不僅操作容易,而且也省去了重新建模的工作。Keras內置的vgg-19/resnet50等模型,實際就使用了這種方式,我們會在下一篇詳細介紹。
保存完整的模型非常簡單,只要在model.fit執行完成后,一行代碼就可以保存完整、包含權重參數的模型:

# 將完整模型保存為HDF5文件
model.save('fashion_mnist.h5')

還原完整模型的話,則可以從使用keras.Sequential開始定義模型、模型編譯都不需要,直接使用:

new_model = keras.models.load_model('fashion_mnist.h5')

接着就可以使用new_model這個模型進行預測了。

(待續...)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM