導出MNIST的數據集


在TensorFlow的官方入門課程中,多次用到mnist數據集。

mnist數據集是一個數字手寫體圖片庫,但它的存儲格式並非常見的圖片格式,所有的圖片都集中保存在四個擴展名為idx3-ubyte的二進制文件。

如果我們想要知道大名鼎鼎的mnist手寫體數字都長什么樣子,就需要從mnist數據集中導出手寫體數字圖片。了解這些手寫體的總體形狀,也有助於加深我們對TensorFlow入門課程的理解。

下面先給出通過TensorFlow api接口導出mnist手寫體數字圖片的python代碼,再對代碼進行分析。代碼在win7下測試通過,linux環境也可以參考本處代碼。

(非常良心的注釋和打印有木有)

 

[python]  view plain  copy
 
  1. #!/usr/bin/python3.5  
  2. # -*- coding: utf-8 -*-  
  3.   
  4. import os  
  5. import tensorflow as tf  
  6. from tensorflow.examples.tutorials.mnist import input_data  
  7.   
  8. from PIL import Image  
  9.   
  10. # 聲明圖片寬高  
  11. rows = 28  
  12. cols = 28  
  13.   
  14. # 要提取的圖片數量  
  15. images_to_extract = 8000  
  16.   
  17. # 當前路徑下的保存目錄  
  18. save_dir = "./mnist_digits_images"  
  19.   
  20. # 讀入mnist數據  
  21. mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)  
  22.   
  23. # 創建會話  
  24. sess = tf.Session()  
  25.   
  26. # 獲取圖片總數  
  27. shape = sess.run(tf.shape(mnist.train.images))  
  28. images_count = shape[0]  
  29. pixels_per_image = shape[1]  
  30.   
  31. # 獲取標簽總數  
  32. shape = sess.run(tf.shape(mnist.train.labels))  
  33. labels_count = shape[0]  
  34.   
  35. # mnist.train.labels是一個二維張量,為便於后續生成數字圖片目錄名,有必要一維化(后來發現只要把數據集的one_hot屬性設為False,mnist.train.labels本身就是一維)  
  36. #labels = sess.run(tf.argmax(mnist.train.labels, 1))  
  37. labels = mnist.train.labels  
  38.   
  39. # 檢查數據集是否符合預期格式  
  40. if (images_count == labels_count) and (shape.size == 1):  
  41.     print ("數據集總共包含 %s 張圖片,和 %s 個標簽" % (images_count, labels_count))  
  42.     print ("每張圖片包含 %s 個像素" % (pixels_per_image))  
  43.     print ("數據類型:%s" % (mnist.train.images.dtype))  
  44.   
  45.     # mnist圖像數據的數值范圍是[0,1],需要擴展到[0,255],以便於人眼觀看  
  46.     if mnist.train.images.dtype == "float32":  
  47.         print ("准備將數據類型從[0,1]轉為binary[0,255]...")  
  48.         for i in range(0,images_to_extract):  
  49.             for n in range(pixels_per_image):  
  50.                 if mnist.train.images[i][n] != 0:  
  51.                     mnist.train.images[i][n] = 255  
  52.             # 由於數據集圖片數量龐大,轉換可能要花不少時間,有必要打印轉換進度  
  53.             if ((i+1)%50) == 0:  
  54.                 print ("圖像浮點數值擴展進度:已轉換 %s 張,共需轉換 %s 張" % (i+1, images_to_extract))  
  55.   
  56.     # 創建數字圖片的保存目錄  
  57.     for i in range(10):  
  58.         dir = "%s/%s/" % (save_dir,i)  
  59.         if not os.path.exists(dir):  
  60.             print ("目錄 ""%s"" 不存在!自動創建該目錄..." % dir)  
  61.             os.makedirs(dir)  
  62.   
  63.     # 通過python圖片處理庫,生成圖片  
  64.     indices = [for x in range(0, 10)]  
  65.     for i in range(0,images_to_extract):  
  66.         img = Image.new("L",(cols,rows))  
  67.         for m in range(rows):  
  68.             for n in range(cols):  
  69.                 img.putpixel((n,m), int(mnist.train.images[i][n+m*cols]))  
  70.         # 根據圖片所代表的數字label生成對應的保存路徑  
  71.         digit = labels[i]  
  72.         path = "%s/%s/%s.bmp" % (save_dir, labels[i], indices[digit])  
  73.         indices[digit] += 1  
  74.         img.save(path)  
  75.         # 由於數據集圖片數量龐大,保存過程可能要花不少時間,有必要打印保存進度  
  76.         if ((i+1)%50) == 0:  
  77.             print ("圖片保存進度:已保存 %s 張,共需保存 %s 張" % (i+1, images_to_extract))  
  78.       
  79. else:  
  80.     print ("圖片數量和標簽數量不一致!")  

 

上述代碼的實現思路如下:

1.讀入mnist手寫體數據;

2.把數據的值從[0,1]浮點范圍轉化為黑白格式(背景為0-黑色,前景為255-白色);

3.根據mnist.train.labels的內容,生成數字索引,也就是建立每一張圖片和其所代表數字的關聯,由此創建對應的保存目錄;

4.循環遍歷mnist.train.images,把每張圖片的像素數據賦值給python圖片處理庫PIL的Image類實例,再調用Image類的save方法把圖片保存在第3步驟中創建的對應目錄。

 

在運行上述代碼之前,你需要確保本地已經安裝python的圖片處理庫PIL,pip安裝命令如下:

pip3 install Pillow

或 pip install Pillow,取決於你的pip版本。

 

上述python代碼運行后,在當前目錄下會生成mnist_digits_images目錄,在該目錄下,可以看到如下內容:

 

可以看到,我們成功地生成了黑底白字的數字圖片。

如果仔細觀察這些圖片,會看到一些肉眼也難以分辨的數字,譬如:

 

上面這幾個數字是2。想不到吧?

下面這兩個是5(看起來更像6):

這個是7:(7長這樣?有句MMP不知當講不當講)

猜猜下面這個是什么:

 

這是大寫的L?不是。

有點像1,是1嗎?也不是。

倒立拉粑的7?sorry,又猜錯了。

實話告訴您,它是2!一開始我也是不相信的,知道真相的那一刻我下巴差點掉下來!

 

這些手寫圖片,一般人用肉眼觀察,識別率能達到98%就不錯了,但是通過TensorFlow搭建的卷積神經網絡識別率可以達到99%,非常地神奇!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM