Tensorflow實現對fashion mnist(衣服,褲子等圖片)數據集的softmax分類


首先我們要明確的是下面我們講解的是一個很基礎的神經網絡,因為我們只是為了通過下面這個實例來為大家解釋如何使用tensorflow2.0這個框架。整個神經網絡的架構是首先是flatten層(把圖片從二維轉化為一維),然后經過一系列的全連接網絡層,中間穿插着一些dropout層來避免過擬合,最后達到softmax層實現多分類。在整個神經網絡當中並沒有用到卷積神經網絡,卷積神經網絡會在我后面的博文當中寫出。

代碼如下:

import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
#這一次我們使用softmax模型來進行對衣服,褲子,鞋子,包包圖像的分類
(train_image,train_label),(test_image,test_label)=tf.keras.datasets.fashion_mnist.load_data()

加載訓練以及測試的圖片和label標簽完畢,然后查看訓練集圖片的shape:

train_image.shape

輸出:

(60000, 28, 28)

使用plt可以查看單個圖片的式樣:

#用plt交互展示出其中的一個圖像
plt.imshow(test_image[4])

輸出如下:

 

 

 進行數據的歸一化,同時搭建神經網絡:

train_image=train_image/255
test_image=test_image/255#進行數據的歸一化,加快計算的進程

model=tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))#因為每一個神經元里只有一個數字,一共有15個input因此這里寫15,
model.add(tf.keras.layers.Dense(200,activation="relu"))
model.add(tf.keras.layers.Dropout(0.5))#添加dropout層,抑制過擬合的效果。
model.add(tf.keras.layers.Dense(300,activation="relu"))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10,activation="softmax")) #在最后一個節點處使用softmax,因為有十個分類,這里都寫錯了,粗心啊!

#然后確立optimizer
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss="sparse_categorical_crossentropy",
              metrics=['acc']
)
#如果l物體的label使用了順序編碼,那么我們就使用sparse_categorical_crossentropy的loss,使用了獨熱編碼,則使用
#categorical_crossentropy的loss

編譯模型:

history=model.fit(train_image,train_label,epochs=15,validation_data=(test_image,test_label))

輸出:

Train on 60000 samples, validate on 10000 samples
Epoch 1/15
60000/60000 [==============================] - 5s 81us/sample - loss: 0.6924 - acc: 0.7534 - val_loss: 0.5699 - val_acc: 0.7996
Epoch 2/15
60000/60000 [==============================] - 4s 72us/sample - loss: 0.6461 - acc: 0.7688 - val_loss: 0.5634 - val_acc: 0.8051
Epoch 3/15
60000/60000 [==============================] - 4s 75us/sample - loss: 0.6292 - acc: 0.7754 - val_loss: 0.5536 - val_acc: 0.8108
Epoch 4/15
60000/60000 [==============================] - 4s 73us/sample - loss: 0.6199 - acc: 0.7784 - val_loss: 0.5492 - val_acc: 0.8065
Epoch 5/15
60000/60000 [==============================] - 4s 73us/sample - loss: 0.6223 - acc: 0.7772 - val_loss: 0.5447 - val_acc: 0.8121
Epoch 6/15
60000/60000 [==============================] - 4s 73us/sample - loss: 0.6155 - acc: 0.7783 - val_loss: 0.5331 - val_acc: 0.8164
Epoch 7/15
60000/60000 [==============================] - 4s 73us/sample - loss: 0.6053 - acc: 0.7810 - val_loss: 0.5377 - val_acc: 0.8136
Epoch 8/15
60000/60000 [==============================] - 4s 74us/sample - loss: 0.6100 - acc: 0.7821 - val_loss: 0.5338 - val_acc: 0.8220
Epoch 9/15
60000/60000 [==============================] - 4s 74us/sample - loss: 0.6069 - acc: 0.7830 - val_loss: 0.5387 - val_acc: 0.8169
Epoch 10/15
60000/60000 [==============================] - 4s 74us/sample - loss: 0.6020 - acc: 0.7843 - val_loss: 0.5317 - val_acc: 0.8223
Epoch 11/15
60000/60000 [==============================] - 4s 74us/sample - loss: 0.5986 - acc: 0.7856 - val_loss: 0.5314 - val_acc: 0.8196
Epoch 12/15
60000/60000 [==============================] - 5s 78us/sample - loss: 0.5884 - acc: 0.7900 - val_loss: 0.5329 - val_acc: 0.8188
Epoch 13/15
60000/60000 [==============================] - 5s 76us/sample - loss: 0.5959 - acc: 0.7835 - val_loss: 0.5555 - val_acc: 0.8087
Epoch 14/15
60000/60000 [==============================] - 4s 74us/sample - loss: 0.5868 - acc: 0.7871 - val_loss: 0.5269 - val_acc: 0.8304
Epoch 15/15
60000/60000 [==============================] - 5s 75us/sample - loss: 0.5880 - acc: 0.7862 - val_loss: 0.5301 - val_acc: 0.8230

模型訓練完畢,現在把訓練的過程以及結果用plt畫出來,突出acc准確率和loss(損失的大小):

history.history.keys()
plt.plot(history.epoch,history.history.get('loss'),label="loss")
plt.plot(history.epoch,history.history.get('val_loss'),label="val_loss")
plt.legend()

圖像如下:

 

 

模型准確率的圖像如下:

 

 

 

 

 從中可以看出,驗證集的准確率在不斷上升,雖然中途比較跌宕起伏,但是總體有上升的趨勢,因此這個模型可以繼續進行迭代增加模型的驗證集准確率(沒有過擬合的緣故)。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM