1. Attributes and methods of mnist
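For convenience, I only inspected the last 20 attributes and methods.

    from tensorflow.examples.tutorials.mnist import input_data

    # Raw string so the backslashes in the Windows path are not treated as escapes
    mnist = input_data.read_data_sets(r'G:\MNIST DATABASE\MNIST_data', one_hot=True)
    print(dir(mnist)[-20:])

The import statement loads the input_data module from the tensorflow.examples.tutorials.mnist package.
The read_data_sets call takes two arguments here. The first is a string naming the folder that holds the data. The second, the keyword argument one_hot, is a bool; setting it to True returns the labels one-hot encoded.
The print statement shows the last 20 attributes and methods of mnist.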
Output:
Extracting G:\MNIST DATABASE\MNIST_data\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting G:\MNIST DATABASE\MNIST_data\t10k-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Extracting G:\MNIST DATABASE\MNIST_data\t10k-labels-idx1-ubyte.gz
['__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '_asdict', '_fields', '_make', '_replace', '_source', 'count', 'index', 'test', 'train', 'validation']
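The names near the end of that list (`_asdict`, `_fields`, `_make`, `_replace`, `count`, `index`) are the ones a `collections.namedtuple` exposes, which suggests that `mnist` is a named tuple with the fields `train`, `validation`, and `test`. The sketch below is a minimal illustration under that assumption; the `Datasets` name and the string placeholders are hypothetical and only stand in for the real objects.

```python
from collections import namedtuple

# Hypothetical stand-in for the object returned by read_data_sets:
# a named tuple with one field per data split.
Datasets = namedtuple("Datasets", ["train", "validation", "test"])

# Plain strings stand in for the real data set objects.
demo = Datasets(train="train split",
                validation="validation split",
                test="test split")

print(demo._fields)     # field names, in declaration order
print(demo.train)       # access by name
print(demo[0])          # access by position, like any tuple
print(dir(demo)[-3:])   # dir() is sorted, so the field names come last
```

This is why `mnist.train`, `mnist.validation`, and `mnist.test` can be read as ordinary attributes.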
2. Number of images in mnist's training, validation, and test sets
The train set has 55,000 images and the validation set has 5,000. Together they make up the 60,000-image training set that MNIST itself provides.
    print('Training examples:', mnist.train.num_examples)
    print('Validation examples:', mnist.validation.num_examples)
    print('Test examples:', mnist.test.num_examples)

Output:
    Training examples: 55000
    Validation examples: 5000
    Test examples: 10000
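The 55,000 / 5,000 split can be pictured as a simple partition of the 60,000 original training images. The sketch below uses integer indices in place of images. It assumes that the first 5,000 samples are the ones held out for validation; that detail is an assumption for illustration and is not confirmed by the output above.

```python
# Indices stand in for the 60,000 images of the original MNIST training set.
full_train = list(range(60000))

# Assumption: the first `validation_size` samples are held out for validation
# and the remainder is kept for training.
validation_size = 5000
validation = full_train[:validation_size]
train = full_train[validation_size:]

print("train:", len(train))
print("validation:", len(validation))

# The two parts do not overlap and together cover the full set.
print("disjoint:", set(train).isdisjoint(validation))
print("complete:", len(train) + len(validation) == len(full_train))
```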
3. The mnist.train.next_batch() function
The object returned by input_data.read_data_sets provides mnist.train.next_batch(), which reads a small portion of the training data to use as one training batch.
    batch_size = 100
    # Take 100 training samples and their 100 labels from the train set
    xs, ys = mnist.train.next_batch(batch_size)
    print('xs shape', xs.shape)
    print('ys shape', ys.shape)

Output:
    xs shape (100, 784)
    ys shape (100, 10)
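To make the idea of a batch reader concrete, here is a simplified stand-in written with the standard library only. It is a sketch of the general technique (shuffle once per pass, then hand out consecutive slices) and is not TensorFlow's implementation; the `MiniBatcher` class and its parameters are hypothetical names introduced for illustration.

```python
import random


class MiniBatcher:
    """Simplified next_batch-style reader (illustrative, not TensorFlow's code)."""

    def __init__(self, samples, labels, seed=0):
        if len(samples) != len(labels):
            raise ValueError("samples and labels must have the same length")
        self._samples = samples
        self._labels = labels
        self._rng = random.Random(seed)
        self._order = []   # shuffled indices for the current pass
        self._cursor = 0   # position of the next unread index

    def next_batch(self, batch_size):
        if batch_size > len(self._samples):
            raise ValueError("batch_size exceeds the number of samples")
        # Start a new shuffled pass when too few unread samples remain;
        # any leftover tail of the previous pass is discarded.
        if self._cursor + batch_size > len(self._order):
            self._order = list(range(len(self._samples)))
            self._rng.shuffle(self._order)
            self._cursor = 0
        picked = self._order[self._cursor:self._cursor + batch_size]
        self._cursor += batch_size
        xs = [self._samples[i] for i in picked]
        ys = [self._labels[i] for i in picked]
        return xs, ys


# Ten toy samples; sample i is paired with label i.
samples = [[i, i] for i in range(10)]
labels = list(range(10))

batcher = MiniBatcher(samples, labels)
xs, ys = batcher.next_batch(4)
print(len(xs), len(ys))
print(all(x[0] == y for x, y in zip(xs, ys)))  # samples stay paired with labels
```

Each call returns samples and labels in matching order, which is the property the training loop relies on.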
4. Inspecting mnist.train.images
mnist.train.images is a NumPy array. Each element is a one-dimensional array of length 784, which is the number of pixels in one image.
    print('Type of train images:', type(mnist.train.images), 'shape of train images:', mnist.train.images.shape)
    print('Type of train labels:', type(mnist.train.labels), 'shape of train labels:', mnist.train.labels.shape)

Output:
    Type of train images: <class 'numpy.ndarray'> shape of train images: (55000, 784)
    Type of train labels: <class 'numpy.ndarray'> shape of train labels: (55000, 10)

    print('First train image, length and contents:', len(mnist.train.images[0]), mnist.train.images[0])
    print('First train label, length and contents:', len(mnist.train.labels[0]), mnist.train.labels[0])

Output (the 784 pixel values are abbreviated with "..."; most are 0 and the rest lie between 0 and 1):
    First train image, length and contents: 784 [ 0.          0.          0.          0.          0.          0.          0.
      ...
      0.          0.          0.          0.          0.          0.38039219  0.37647063
      0.3019608   0.46274513  0.2392157   0.          0.          0.          0.
      ...
      0.          0.          0.          0.          0.          0.          0.        ]
    First train label, length and contents: 10 [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]
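The label printed above has its single 1 at index 7, so it encodes the digit 7. The following sketch shows one common way to build and decode such one-hot vectors with NumPy. It is an illustration of the encoding, not the code that read_data_sets runs internally.

```python
import numpy as np

num_classes = 10

# Encode: digit 7 becomes a length-10 vector with a single 1 at index 7.
digit = 7
one_hot = np.zeros(num_classes, dtype=np.float32)
one_hot[digit] = 1.0
print(one_hot)

# Decode: the index of the largest entry recovers the digit.
decoded = int(np.argmax(one_hot))
print(decoded)

# The same idea works for a whole batch of labels at once:
# row k of the identity matrix is the one-hot vector for class k.
digits = np.array([3, 0, 7])
batch = np.eye(num_classes, dtype=np.float32)[digits]
print(batch.shape)
print(np.argmax(batch, axis=1))
```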
The output above shows that mnist.train holds 55,000 samples with 784 features each. The original images are 28 x 28 pixels, and 28 * 28 = 784, so each flattened image has 784 features.
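The relationship between the flat 784-value vector and the 28 x 28 image can be checked with a small NumPy sketch. The array here is synthetic (a ramp of values, not a real digit) and only stands in for one row of mnist.train.images.

```python
import numpy as np

# Synthetic stand-in for one flattened MNIST image: 784 values in [0, 1).
flat = np.arange(784, dtype=np.float32) / 784.0

# -1 asks NumPy to infer that dimension: 784 / 28 = 28 rows.
image = flat.reshape(-1, 28)
print(image.shape)

# Row-major order: pixel (row, col) comes from flat index row * 28 + col.
row, col = 5, 3
print(image[row, col] == flat[row * 28 + col])

# Flattening the 2-D image gives back the original vector.
print(np.array_equal(image.reshape(-1), flat))
```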
Pick one sample and view it with three plotting styles, using the code below:
    import matplotlib.pyplot as plt

    # mnist is the object loaded with read_data_sets in section 1
    # Convert the flat array back into a 28 x 28 image
    image = mnist.train.images[1].reshape(-1, 28)
    fig = plt.figure("Image display")
    ax0 = fig.add_subplot(131)
    ax0.imshow(image)              # default colormap
    ax0.axis('off')                # hide the axes

    plt.subplot(132)
    plt.imshow(image, cmap='gray')
    plt.axis('off')                # hide the axes

    plt.subplot(133)
    plt.imshow(image, cmap='gray_r')
    plt.axis('off')
    plt.show()
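Output (figure not reproduced here): the same digit drawn three times, with the default colormap, with gray, and with gray_r.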
The result shows that passing cmap='gray' or cmap='gray_r' to plt.imshow gives the image a natural grayscale look.
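The difference between the two grayscale maps is only the direction of the ramp: gray shows a value of 0 as black and 1 as white, while gray_r reverses this. Since MNIST stores the background as 0, gray gives a light digit on a dark background and gray_r gives a dark digit on a light background. The sketch below imitates that with text characters. The `brightness` and `render` helpers and the 5 x 5 image are hypothetical stand-ins written for this illustration, not matplotlib code.

```python
# Tiny 5x5 stand-in image: 0.0 is background, 1.0 is full stroke intensity.
image = [
    [0.0, 0.0, 1.0, 0.0, 0.0],
    [0.0, 1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0, 0.0],
    [0.0, 1.0, 1.0, 1.0, 0.0],
]


def brightness(value, reverse=False):
    """Displayed brightness for a pixel value in [0, 1]: 0 is black, 1 is white.

    Hand-written stand-in for a grayscale colormap lookup.
    """
    return 1.0 - value if reverse else value


def render(img, reverse=False):
    """Return text rows where '#' marks a dark pixel and '.' a light one."""
    return [
        "".join("#" if brightness(v, reverse) < 0.5 else "." for v in row)
        for row in img
    ]


print("gray (light digit on dark background):")
print("\n".join(render(image)))
print("gray_r (dark digit on light background):")
print("\n".join(render(image, reverse=True)))
```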
5. Viewing the handwritten digits
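Pick a subset of samples from the training set mnist.train and view their images. That is, call the next_batch method of mnist.train to get a random subset of samples. The code is as follows: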
    from tensorflow.examples.tutorials.mnist import input_data
    import math
    import matplotlib.pyplot as plt
    import numpy as np

    # Raw string so the backslashes in the Windows path are not treated as escapes
    mnist = input_data.read_data_sets(r'G:\MNIST DATABASE\MNIST_data', one_hot=True)

    # Draw a single MNIST sample at the given subplot position
    def drawdigit(position, image, title):
        plt.subplot(*position)
        plt.imshow(image, cmap='gray_r')
        plt.axis('off')
        plt.title(title)

    # Fetch one batch, then draw batch_size subplots on a single figure
    def batchDraw(batch_size):
        images, labels = mnist.train.next_batch(batch_size)
        row_num = math.ceil(batch_size ** 0.5)
        column_num = row_num
        plt.figure(figsize=(row_num, column_num))
        for i in range(row_num):
            for j in range(column_num):
                index = i * column_num + j
                if index < batch_size:
                    position = (row_num, column_num, index + 1)
                    image = images[index].reshape(-1, 28)
                    # argmax turns the one-hot label back into a digit
                    title = 'actual:%d' % (np.argmax(labels[index]))
                    drawdigit(position, image, title)

    if __name__ == '__main__':
        batchDraw(196)
        plt.show()
Output (figure not reproduced here): a 14 x 14 grid of 196 digits, each titled with its actual label.
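The grid arithmetic inside batchDraw can be checked without plotting anything. The sketch below repeats the same calculation in a standalone helper; `grid_layout` is a hypothetical name introduced here, and it returns the subplot positions instead of drawing them.

```python
import math


def grid_layout(batch_size):
    """Return (rows, cols, positions) for a square subplot grid.

    positions[k] is the (rows, cols, index) triple for sample k, with a
    1-based index that fills the grid row by row, as plt.subplot expects.
    """
    side = math.ceil(batch_size ** 0.5)
    positions = []
    for i in range(side):
        for j in range(side):
            index = i * side + j
            if index < batch_size:   # cells past the batch stay empty
                positions.append((side, side, index + 1))
    return side, side, positions


rows, cols, positions = grid_layout(196)
print(rows, cols, len(positions))

# A batch size that is not a perfect square leaves some cells unused.
rows10, cols10, positions10 = grid_layout(10)
print(rows10, cols10, len(positions10), rows10 * cols10 - len(positions10))
```

With 196 samples the grid is filled exactly; with 10 samples the grid is 4 x 4 and the last 6 cells are left blank.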