從TensorFlow的mnist數據集導出手寫體數字圖片


在TensorFlow的官方入門課程中,多次用到mnist數據集。

mnist數據集是一個數字手寫體圖片庫,但它的存儲格式並非常見的圖片格式,所有的圖片都集中保存在四個擴展名為idx3-ubyte的二進制文件。

如果我們想要知道大名鼎鼎的mnist手寫體數字都長什么樣子,就需要從mnist數據集中導出手寫體數字圖片。了解這些手寫體的總體形狀,也有助於加深我們對TensorFlow入門課程的理解。

下面先給出通過TensorFlow api接口導出mnist手寫體數字圖片的python代碼,再對代碼進行分析。代碼在win7下測試通過,linux環境也可以參考本處代碼。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
from PIL import Image

# 聲明圖片寬高
rows = 28
cols = 28

# 要提取的圖片數量
images_to_extract = 8000

# 當前路徑下的保存目錄
save_dir = "./mnist_digits_images"

# 讀入mnist數據
mnist = input_data.read_data_sets("C:\\Users\\Administrator\\Desktop\\Tensorflow\\數據集\\mnist\\", one_hot=False)

# 創建會話
sess = tf.Session()

# 獲取圖片總數
shape = sess.run(tf.shape(mnist.train.images))
images_count = shape[0]
pixels_per_image = shape[1]

# 獲取標簽總數
shape = sess.run(tf.shape(mnist.train.labels))
labels_count = shape[0]

# mnist.train.labels是一個二維張量,為便於后續生成數字圖片目錄名,有必要一維化(后來發現只要把數據集的one_hot屬性設為False,mnist.train.labels本身就是一維)
# labels = sess.run(tf.argmax(mnist.train.labels, 1))
labels = mnist.train.labels

# 檢查數據集是否符合預期格式
if (images_count == labels_count) and (shape.size == 1):
    print("數據集總共包含 %s 張圖片,和 %s 個標簽" % (images_count, labels_count))
    print("每張圖片包含 %s 個像素" % (pixels_per_image))
    print("數據類型:%s" % (mnist.train.images.dtype))

    # mnist圖像數據的數值范圍是[0,1],需要擴展到[0,255],以便於人眼觀看
    if mnist.train.images.dtype == "float32":
        print("准備將數據類型從[0,1]轉為binary[0,255]...")
        for i in range(0, images_to_extract):
            for n in range(pixels_per_image):
                if mnist.train.images[i][n] != 0:
                    mnist.train.images[i][n] = 255
            # 由於數據集圖片數量龐大,轉換可能要花不少時間,有必要打印轉換進度
            if ((i + 1) % 50) == 0:
                print("圖像浮點數值擴展進度:已轉換 %s 張,共需轉換 %s 張" % (i + 1, images_to_extract))

    # 創建數字圖片的保存目錄
    for i in range(10):
        dir = "%s/%s/" % (save_dir, i)
        if not os.path.exists(dir):
            print("目錄 ""%s"" 不存在!自動創建該目錄..." % dir)
            os.makedirs(dir)

    # 通過python圖片處理庫,生成圖片
    indices = [0 for x in range(0, 10)]
    for i in range(0, images_to_extract):
        img = Image.new("L", (cols, rows))
        for m in range(rows):
            for n in range(cols):
                img.putpixel((n, m), int(mnist.train.images[i][n + m * cols]))
        # 根據圖片所代表的數字label生成對應的保存路徑
        digit = labels[i]
        path = "%s/%s/%s.bmp" % (save_dir, labels[i], indices[digit])
        indices[digit] += 1
        img.save(path)
        # 由於數據集圖片數量龐大,保存過程可能要花不少時間,有必要打印保存進度
        if ((i + 1) % 50) == 0:
            print("圖片保存進度:已保存 %s 張,共需保存 %s 張" % (i + 1, images_to_extract))
else:
    print("圖片數量和標簽數量不一致!")

上述代碼的實現思路如下:

1.讀入mnist手寫體數據;

2.把數據的值從[0,1]浮點范圍轉化為黑白格式(背景為0-黑色,前景為255-白色);

3.根據mnist.train.labels的內容,生成數字索引,也就是建立每一張圖片和其所代表數字的關聯,由此創建對應的保存目錄;

4.循環遍歷mnist.train.images,把每張圖片的像素數據賦值給python圖片處理庫PIL的Image類實例,再調用Image類的save方法把圖片保存在第3步驟中創建的對應目錄。

 

轉自https://www.jb51.net/article/166936.htm


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM