How to Convert notMNIST to MNIST Format


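Anyone familiar with machine learning will know MNIST. Yaroslav Bulatov, an engineer at Google, created notMNIST, which closely resembles it: the images are 28x28 and there are 10 labels (A-J).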

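TensorFlow already wraps reading the MNIST dataset in the function read_data_sets() (this import path belongs to TensorFlow 1.x; tensorflow.contrib was removed in 2.0):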

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
mnist = read_data_sets("data", one_hot=True, reshape=False, validation_size=0)

However, notMNIST is not stored in the same format as MNIST, so a TensorFlow model written for MNIST cannot read the notMNIST images directly.
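To make the difference concrete: notMNIST ships as folders of PNG files (one folder per letter), whereas the MNIST files that read_data_sets() expects use the binary IDX layout. An IDX file starts with a big-endian 32-bit magic number (0x0803 for images, 0x0801 for labels), followed by one big-endian 32-bit size per dimension, then the raw unsigned bytes. The sketch below is only an illustration and is not part of the conversion script; parse_idx_header is a hypothetical helper name, and the sample bytes are built in memory rather than read from a real dataset.

```python
import struct

def parse_idx_header(buf):
    """Parse an IDX header from bytes; return (magic, dims, data_offset)."""
    # First 4 bytes: two zero bytes, a type code (0x08 = unsigned byte),
    # and the number of dimensions.
    magic, = struct.unpack(">I", buf[:4])
    ndim = magic & 0xFF
    # Each dimension size is a big-endian unsigned 32-bit integer.
    dims = struct.unpack(">" + "I" * ndim, buf[4:4 + 4 * ndim])
    return magic, dims, 4 + 4 * ndim

# Illustrative in-memory "image file": header for 2 images of 28x28, then pixels.
sample_images = struct.pack(">IIII", 0x0803, 2, 28, 28) + bytes(2 * 28 * 28)
# Illustrative in-memory "label file": header for 2 labels, then the labels.
sample_labels = struct.pack(">II", 0x0801, 2) + bytes([0, 9])

image_header = parse_idx_header(sample_images)
label_header = parse_idx_header(sample_labels)
print(image_header)
print(label_header)
```

The low byte of the magic number doubles as the dimension count, which is why the image header is 16 bytes long and the label header only 8.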

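Someone on GitHub has written format-conversion code (https://github.com/davidflanagan/notMNIST-to-MNIST). After conversion, the data can be read directly with read_data_sets(), so the model code needs few changes. This article annotates that code with the notes I made while reading it.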

import numpy, imageio, glob, sys, os, random
# imageio provides a simple interface for reading and writing image data.
# glob finds files by pattern using three wildcards: "*" matches zero or more
# characters, "?" matches a single character, and "[]" matches any character
# in the given range (for example, [0-9] matches a digit). The script
# imports glob but does not call it below.

def get_labels_and_files(folder, number):
  # Make a list of lists of files for each label
  filelists = []
  for label in range(0, 10):
    filelist = []
    filelists.append(filelist)
    # label runs from 0 to 9, so chr(ord('A') + label) yields 'A' to 'J'
    # and dirname becomes folder/[A-J]
    dirname = os.path.join(folder, chr(ord('A') + label))
    # os.listdir returns a list of the file names in dirname
    for file in os.listdir(dirname):
      if file.endswith('.png'):
        fullname = os.path.join(dirname, file)
        if os.path.getsize(fullname) > 0:
          filelist.append(fullname)
        else:
          print('file ' + fullname + ' is empty')
    # sort each list of files so they start off in the same order
    # regardless of the order the OS returns them in
    filelist.sort()

  # Take the specified number of items for each label and
  # build them into an array of (label, filename) pairs.
  # If the RNG has been seeded (see main), we get the same sample each run.
  labelsAndFiles = []
  for label in range(0, 10):
    # randomly sample `number` file names for this label
    filelist = random.sample(filelists[label], number)
    for filename in filelist:
      # each entry is a tuple; tuples are like lists but immutable,
      # and are written with parentheses instead of square brackets
      labelsAndFiles.append((label, filename))
  return labelsAndFiles

def make_arrays(labelsAndFiles):
  images = []
  labels = []
  for i in range(0, len(labelsAndFiles)):

    # display progress, since this can take a while
    if i % 100 == 0:
      # "\r" returns the cursor to the start of the line, so each
      # update overwrites the previous one
      sys.stdout.write("\r%d%% complete" % ((i * 100) // len(labelsAndFiles)))
      sys.stdout.flush()

    filename = labelsAndFiles[i][1]
    try:
      image = imageio.imread(filename)
      images.append(image)
      labels.append(labelsAndFiles[i][0])
    except Exception:
      # If this happens we won't have the requested number
      print("\nCan't read image file " + filename)

  count = len(images)
  imagedata = numpy.zeros((count, 28, 28), dtype=numpy.uint8)
  labeldata = numpy.zeros(count, dtype=numpy.uint8)
  # Loop over the images that were actually read; looping over
  # len(labelsAndFiles) would raise IndexError if any file was skipped
  for i in range(0, count):
    imagedata[i] = images[i]
    labeldata[i] = labels[i]
  print("\n")
  return imagedata, labeldata

def write_labeldata(labeldata, outputfile):
  header = numpy.array([0x0801, len(labeldata)], dtype='>i4')
  # Open in binary write mode. The with statement guarantees the file
  # handle is closed when the block ends, even if an exception is raised.
  with open(outputfile, "wb") as f:
    # write the header, then the labels, as raw bytes
    f.write(header.tobytes())
    f.write(labeldata.tobytes())

def write_imagedata(imagedata, outputfile):
  header = numpy.array([0x0803, len(imagedata), 28, 28], dtype='>i4')
  with open(outputfile, "wb") as f:
    f.write(header.tobytes())
    f.write(imagedata.tobytes())

def main(argv):
  # Uncomment the line below if you want to seed the random
  # number generator in the same way I did to produce the
  # specific data files in this repo.
  # random.seed(int("notMNIST", 36))
  # With the same seed, the same random numbers are generated on every
  # run; without a seed, each run produces different ones.

  labelsAndFiles = get_labels_and_files(argv[1], int(argv[2]))
  # shuffle the (label, filename) pairs into random order
  random.shuffle(labelsAndFiles)

  imagedata, labeldata = make_arrays(labelsAndFiles)
  write_labeldata(labeldata, argv[3])
  write_imagedata(imagedata, argv[4])

# Make the script both importable and executable: when the file is run
# directly, __name__ == '__main__' is True; when it is imported by another
# module, __name__ != '__main__', so main() is not executed.
if __name__ == '__main__':
  main(sys.argv)
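The commented-out seed in main() is what makes a conversion repeatable. As a small illustration, separate from the script itself, the sketch below samples and then shuffles a made-up list of file names twice with the same seed. The helper name pick and the file names are hypothetical; int("notMNIST", 36) simply parses the string as a base-36 number to obtain a fixed integer seed.

```python
import random

def pick(names, k, seed=None):
    """Sample k names and shuffle them, using a private generator."""
    rng = random.Random(seed)  # private generator, so global state is untouched
    chosen = rng.sample(sorted(names), k)
    rng.shuffle(chosen)
    return chosen

# Made-up file names standing in for one label's directory listing.
names = ["A/%03d.png" % i for i in range(50)]

# Base-36 parse: digits 0-9 followed by letters a-z, case-insensitive.
seed = int("notMNIST", 36)
first = pick(names, 5, seed)
second = pick(names, 5, seed)
print(first == second)
```

The script seeds the module-level random functions instead; a private random.Random instance is used here only so the example does not disturb other code that relies on the global generator.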

Usage

Download and extract notMNIST:

curl -o notMNIST_small.tar.gz http://yaroslavvb.com/upload/notMNIST/notMNIST_small.tar.gz
curl -o notMNIST_large.tar.gz http://yaroslavvb.com/upload/notMNIST/notMNIST_large.tar.gz
tar xzf notMNIST_small.tar.gz
tar xzf notMNIST_large.tar.gz

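Run the conversion script (the data directory must already exist, since the script opens its output files without creating parent folders):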

python convert_to_mnist_format.py notMNIST_small 1000 data/t10k-labels-idx1-ubyte data/t10k-images-idx3-ubyte
python convert_to_mnist_format.py notMNIST_large 6000 data/train-labels-idx1-ubyte data/train-images-idx3-ubyte
gzip data/*ubyte
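Before handing the gzipped files to read_data_sets(), it can be worth checking that they read back correctly. The following is a minimal sketch that uses only gzip and NumPy; load_idx is a hypothetical helper, not something provided by the repository or by TensorFlow, and the demo writes a tiny throwaway file to a temporary directory rather than assuming your data/ folder exists.

```python
import gzip
import os
import tempfile

import numpy

def load_idx(path):
    """Load a gzipped IDX file of unsigned bytes into a NumPy array."""
    with gzip.open(path, "rb") as f:
        raw = f.read()
    magic = int.from_bytes(raw[:4], "big")
    # The third header byte is the type code; 0x08 means unsigned byte.
    if magic >> 8 != 0x08:
        raise ValueError("not an unsigned-byte IDX file: magic=%#x" % magic)
    ndim = magic & 0xFF
    dims = numpy.frombuffer(raw, dtype=">i4", count=ndim, offset=4)
    data = numpy.frombuffer(raw, dtype=numpy.uint8, offset=4 + 4 * ndim)
    return data.reshape([int(d) for d in dims])

# Demo: write a tiny label file with the same header layout as
# write_labeldata, gzip it, and read it back.
labels = numpy.array([0, 3, 9], dtype=numpy.uint8)
header = numpy.array([0x0801, len(labels)], dtype=">i4")
demo_path = os.path.join(tempfile.mkdtemp(), "demo-labels-idx1-ubyte.gz")
with gzip.open(demo_path, "wb") as f:
    f.write(header.tobytes())
    f.write(labels.tobytes())

loaded = load_idx(demo_path)
print(loaded.shape, loaded.dtype)
```

For the real files you would call, for example, load_idx("data/t10k-images-idx3-ubyte.gz") and expect a shape of (N, 28, 28), where N is ten times the per-label count passed to the script, minus any images that could not be read.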

 

  

