python讀取,顯示,保存mnist圖片


python處理二進制

python的struct模塊可以將整型(或者其它類型)轉化為byte數組.看下面的代碼.

# coding: utf-8
from struct import pack

# Pack into a big-endian byte string: '>' = big-endian, 'h' = 2-byte short,
# 'l' = 4-byte long (standard sizes apply once a byte-order prefix is given).
print(pack('>hhl', 1, 2, 3))  # b'\x00\x01\x00\x02\x00\x00\x00\x03'

pack('>hhl', 1, 2, 3)的作用是以大端方式把1和2(h表示2字節整型)以及3(l表示4字節整型)轉化為對應的byte數組.大端小端的區別見參考資料2,>hhl的含義見參考資料1.輸出是長度為8的byte數組:2個h共4字節,1個l為4字節,加起來一共是8字節.
再體會下面代碼的作用.

# coding: utf-8
from struct import pack, unpack

# Pack into a big-endian byte string.
big = pack('>hhl', 1, 2, 3)
print(big)  # b'\x00\x01\x00\x02\x00\x00\x00\x03'
# Unpacking with the same (big-endian) byte order restores the integers.
print(unpack('>hhl', big))  # (1, 2, 3)


# Pack into a little-endian byte string.
little = pack('<hhl', 1, 2, 3)
print(little)  # b'\x01\x00\x02\x00\x03\x00\x00\x00'
# Unpacking with the same (little-endian) byte order restores the integers.
print(unpack('<hhl', little))  # (1, 2, 3)
# Deliberate mismatch: reading the *big-endian* bytes as little-endian gives
# different numbers, which is why the byte order used for reading must match
# the one used for writing.
print(unpack('<hhl', big))  # (256, 512, 50331648)

mnist顯示

以t10k-images.idx3-ubyte為例,它是一個二進制文件,格式如下:

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  10000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel

前4個整型(每個4字節,共16字節,依次為魔數,圖片數量,行數,列數)是文件頭信息.之後的無符號byte數組才是圖片的內容,每個像素佔1字節.所以要先跳過前4個整型,然後再開始讀取,代碼如下

import numpy as np
import struct
import matplotlib.pyplot as plt

filename = r'D:\source\technology_source\data\t10k-images.idx3-ubyte'
# Read the whole file into memory; `with` closes the file even if the read fails.
with open(filename, 'rb') as binfile:
    buf = binfile.read()

index = 0
# File header: four big-endian unsigned 32-bit integers (16 bytes in total):
# magic number, number of images, number of rows, number of columns.
magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', buf, index)
if magic != 2051:
    raise ValueError('Invalid magic number %d in MNIST image file: %s'
                     % (magic, filename))
index += struct.calcsize('>IIII')

# One image is numRows * numColumns unsigned bytes (28*28 = 784 for MNIST).
# The size comes from the header instead of being hard-coded; byte order is
# irrelevant for single-byte values.
pixelFormat = '>%dB' % (numRows * numColumns)
im = struct.unpack_from(pixelFormat, buf, index)

im = np.array(im).reshape(numRows, numColumns)
fig = plt.figure()
fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
# To save the image, call plt.savefig("test.png") here, *before* plt.show():
# once the window is closed the figure is usually discarded, and a later
# savefig would write a blank image.
plt.show()
plt.close()

運行後會彈出窗口顯示文件中的第一張圖片.

下面是讀取多個圖片並存盤的代碼.

import numpy as np
import struct
import matplotlib.pyplot as plt

filename = r'D:\source\technology_source\data\t10k-images.idx3-ubyte'
# Read the whole file up front; `with` closes it right after the read,
# so nothing leaks if plotting below raises.
with open(filename, 'rb') as binfile:
    buf = binfile.read()

index = 0
# Header: magic number, image count, rows, columns (4 big-endian uint32).
magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', buf, index)
if magic != 2051:
    raise ValueError('Invalid magic number %d in MNIST image file: %s'
                     % (magic, filename))
index += struct.calcsize('>IIII')

# Pre-compile the per-image format once; its size comes from the header.
imageStruct = struct.Struct('>%dB' % (numRows * numColumns))
# Save the first 30 images (fewer if the file holds less than 30).
for i in range(min(30, numImages)):
    im = imageStruct.unpack_from(buf, index)
    index += imageStruct.size
    im = np.array(im).reshape(numRows, numColumns)
    fig = plt.figure()
    fig.add_subplot(111)
    plt.axis('off')
    plt.imshow(im, cmap='gray')
    plt.savefig("test" + str(i) + ".png")
    # Close each figure so they do not accumulate in memory.
    plt.close(fig)

另外一種方法

參考tensorflow中mnist模塊的方法讀取,代碼如下

import gzip
import numpy
import matplotlib.pyplot as plt
# Location of the gzip-compressed MNIST training images; adjust for your machine.
filepath = "D:\\train-images-idx3-ubyte.gz"
def _read32(bytestream):
    """Read one big-endian unsigned 32-bit integer from *bytestream*.

    Consumes exactly 4 bytes and returns them as a ``numpy.uint32`` scalar.

    Raises:
        ValueError: if fewer than 4 bytes remain (e.g. a truncated file),
            instead of the less helpful error numpy gives for a short buffer.
    """
    raw = bytestream.read(4)
    if len(raw) != 4:
        raise ValueError('Unexpected end of stream: expected 4 bytes, got %d'
                         % len(raw))
    # '>' makes numpy interpret the bytes as big-endian, as the IDX header is.
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(raw, dtype=dt)[0]
def imagine_arr(filepath, index):
    """Return image number *index* from a gzip-compressed MNIST image file.

    Args:
        filepath: path to a gzip-compressed IDX image file.
        index: zero-based image number. Any value outside ``[0, count)``,
            including negative values, falls back to image 0.

    Returns:
        A ``numpy.ubyte`` array of shape ``(rows, cols)``, with the dimensions
        taken from the file header. It is read-only because it is backed by
        the bytes read from the file.

    Raises:
        ValueError: if the magic number is not 2051, or the file ends before
            the requested image is complete.
    """
    with open(filepath, 'rb') as f:
        with gzip.GzipFile(fileobj=f) as bytestream:
            magic = _read32(bytestream)
            if magic != 2051:
                raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name))
            # Use Python ints so the size arithmetic cannot overflow uint32
            # or reject a negative operand.
            num = int(_read32(bytestream))  # number of images in the file
            rows = int(_read32(bytestream))
            cols = int(_read32(bytestream))
            image_size = rows * cols
            if not 0 <= index < num:
                index = 0
            # Skip the preceding images without building a throwaway buffer
            # (whence=1: relative to the current position).
            bytestream.seek(image_size * index, 1)
            buf = bytestream.read(image_size)
            if len(buf) != image_size:
                raise ValueError('Unexpected end of MNIST image file: %s' % f.name)
            data = numpy.frombuffer(buf, dtype=numpy.ubyte)
            return data.reshape(rows, cols)
# Load image number 0 and display it in a single axes without ticks or frame.
im = imagine_arr(filepath, 0)
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_axis_off()
plt.imshow(im, cmap='gray')
plt.show()
plt.close()

用的是numpy裡面的方法,並且通過gzip.GzipFile直接讀取.gz壓縮文件,無需先解壓.函數_read32的作用是讀取4個字節,以大端的方式轉化成無符號整型.其餘代碼邏輯和之前的類似.
一次顯示多張

import gzip
import numpy
import matplotlib.pyplot as plt
# Location of the gzip-compressed MNIST training images; adjust for your machine.
filepath = "D:\\PrjGit\\AI\\py35ts100\\tstutorial\\asset\\data\\mnist\\train-images-idx3-ubyte.gz"
def _read32(bytestream):
    """Read one big-endian unsigned 32-bit integer from *bytestream*.

    Consumes exactly 4 bytes and returns them as a ``numpy.uint32`` scalar.

    Raises:
        ValueError: if fewer than 4 bytes remain (e.g. a truncated file),
            instead of the less helpful error numpy gives for a short buffer.
    """
    raw = bytestream.read(4)
    if len(raw) != 4:
        raise ValueError('Unexpected end of stream: expected 4 bytes, got %d'
                         % len(raw))
    # '>' makes numpy interpret the bytes as big-endian, as the IDX header is.
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(raw, dtype=dt)[0]
def imagine_arr(filepath, img_num=64):
    """Return the first *img_num* images of a gzip-compressed MNIST image file.

    Args:
        filepath: path to a gzip-compressed IDX image file.
        img_num: number of images to read from the start of the file
            (default 64, one per cell of an 8x8 grid).

    Returns:
        A ``numpy.ubyte`` array of shape ``(img_num, rows, cols, 1)``, with
        rows/cols taken from the file header. The trailing length-1 axis is
        presumably a channel axis, following the TensorFlow-style layout.

    Raises:
        ValueError: if the magic number is not 2051, *img_num* is negative or
            larger than the number of images in the file, or the file ends
            early.
    """
    with open(filepath, 'rb') as f:
        with gzip.GzipFile(fileobj=f) as bytestream:
            magic = _read32(bytestream)
            if magic != 2051:
                raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name))
            # Use Python ints so the size arithmetic cannot overflow uint32.
            num = int(_read32(bytestream))  # number of images in the file
            rows = int(_read32(bytestream))
            cols = int(_read32(bytestream))
            if not 0 <= img_num <= num:
                raise ValueError('Cannot read %d images from MNIST image file %s '
                                 'containing %d images' % (img_num, f.name, num))
            expected = rows * cols * img_num
            buf = bytestream.read(expected)
            if len(buf) != expected:
                raise ValueError('Unexpected end of MNIST image file: %s' % f.name)
            data = numpy.frombuffer(buf, dtype=numpy.ubyte)
            return data.reshape(img_num, rows, cols, 1)
im_data = imagine_arr(filepath)
# Show the loaded images on an 8x8 grid, one image per axes.
fig, axes = plt.subplots(8, 8)
for ax, img in zip(axes.flat, im_data):
    # Drop the trailing length-1 axis: imshow needs a 2-D array for a
    # grayscale image, and this keeps the image's own rows/cols instead of
    # hard-coding 28x28.
    ax.imshow(img[:, :, 0], cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()
plt.close()

參考資料

  1. python struct官方文檔
  2. Big and Little Endian
  3. python讀取mnist 2012
  4. mnist數據集官網
  5. Not another MNIST tutorial with TensorFlow 2016


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM